Skip to main content

tork_core/throttle/
engine.rs

1//! The throttle engine: policies, the runtime, and enforcement.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::sync::Arc;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
8use http::header::RETRY_AFTER;
9use http::HeaderValue;
10
11use crate::error::{Error, Result};
12use crate::extract::{FromRequest, RequestContext};
13use crate::response::{IntoResponse, Response};
14
15use super::key::{ByIp, ThrottleKey};
16use super::store::{MemoryThrottleStore, ThrottleStore};
17
/// The rate-limit policy attached to a route, emitted by the `throttle` attribute.
///
/// Construct-able in `const` context so the route macro can emit it directly.
/// Policies compare and hash by value, so two routes declaring the same policy
/// are equal.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ThrottlePolicy {
    /// Use the application's global default policy, if one is configured.
    Inherit,
    /// Skip rate limiting for this route entirely.
    Skip,
    /// Use a globally-defined named policy.
    Named(&'static str),
    /// An inline limit of `limit` requests per `window_secs` seconds.
    Inline { limit: u32, window_secs: u64 },
    /// Apply several named policies at once (for example a per-second and a
    /// per-minute limit); the request is allowed only if every one allows it.
    Multiple(&'static [&'static str]),
}
35
/// A concrete rate limit produced by resolving a policy.
#[derive(Copy, Clone)]
struct Limit {
    /// How many requests a single window admits before denying.
    limit: u32,
    /// Length of one counting window.
    window: Duration,
}
42
43/// Configures rate limiting; see [`App::throttle`](crate::App::throttle).
44///
45/// Define named policies, optionally mark one as the global default (applied to
46/// every route that does not set its own), choose the store, and pick the
47/// algorithm (fixed window by default, or [`sliding`](Throttle::sliding)).
48pub struct Throttle {
49    policies: HashMap<String, Limit>,
50    default: Option<String>,
51    store: Arc<dyn ThrottleStore>,
52    sliding: bool,
53}
54
55impl Throttle {
56    /// Creates an empty configuration backed by an in-memory store.
57    pub fn new() -> Self {
58        Self {
59            policies: HashMap::new(),
60            default: None,
61            store: Arc::new(MemoryThrottleStore::new()),
62            sliding: false,
63        }
64    }
65
66    /// Defines a named policy of `limit` requests per `window_secs` seconds.
67    pub fn policy(mut self, name: &str, limit: u32, window_secs: u64) -> Self {
68        self.policies.insert(
69            name.to_owned(),
70            Limit {
71                limit,
72                window: Duration::from_secs(window_secs.max(1)),
73            },
74        );
75        self
76    }
77
78    /// Marks the named policy as the global default, applied to every route that
79    /// does not declare its own `throttle`.
80    pub fn default(mut self, name: &str) -> Self {
81        self.default = Some(name.to_owned());
82        self
83    }
84
85    /// Uses a custom counter store (the default is in-memory).
86    pub fn store(mut self, store: impl ThrottleStore) -> Self {
87        self.store = Arc::new(store);
88        self
89    }
90
91    /// Switches from a fixed window to a sliding window.
92    ///
93    /// A sliding window weights the previous window's count by how much of it
94    /// still overlaps the current moment, smoothing out the burst a fixed window
95    /// allows at its boundaries.
96    pub fn sliding(mut self) -> Self {
97        self.sliding = true;
98        self
99    }
100
101    /// Uses a Redis store sharing the given connection, for distributed limiting.
102    #[cfg(feature = "redis")]
103    pub fn redis(mut self, redis: &crate::Redis) -> Self {
104        self.store = Arc::new(super::redis::RedisThrottleStore::new(redis));
105        self
106    }
107}
108
109impl Default for Throttle {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115/// The runtime throttle engine, injectable and used by generated route code.
116#[derive(Clone)]
117pub struct Throttler {
118    inner: Arc<Inner>,
119}
120
121struct Inner {
122    policies: HashMap<String, Limit>,
123    default: Option<(String, Limit)>,
124    store: Arc<dyn ThrottleStore>,
125    sliding: bool,
126}
127
/// Verdict produced when a request is checked against its rate limits.
enum Decision {
    /// Within the limit(s); the request may proceed.
    Allow,
    /// A limit was exceeded. `retry_after` is the suggested wait, in seconds.
    Deny { retry_after: u64 },
}
133
134impl Throttler {
135    /// Builds the engine from its configuration.
136    pub fn new(config: Throttle) -> Self {
137        let default = config.default.as_ref().and_then(|name| {
138            config
139                .policies
140                .get(name)
141                .map(|limit| (name.clone(), *limit))
142        });
143        Self {
144            inner: Arc::new(Inner {
145                policies: config.policies,
146                default,
147                store: config.store,
148                sliding: config.sliding,
149            }),
150        }
151    }
152
153    /// Resolves a policy into the concrete limits to enforce, each with a stable
154    /// discriminator so different policies on one route count separately. An empty
155    /// list means the route is not limited (skipped, or inherits with no default).
156    fn resolve(&self, policy: &ThrottlePolicy) -> Vec<(String, Limit)> {
157        match policy {
158            ThrottlePolicy::Skip => Vec::new(),
159            ThrottlePolicy::Inherit => self
160                .inner
161                .default
162                .as_ref()
163                .map(|(name, limit)| vec![(name.clone(), *limit)])
164                .unwrap_or_default(),
165            ThrottlePolicy::Inline { limit, window_secs } => vec![(
166                format!("inline:{limit}:{window_secs}"),
167                Limit {
168                    limit: *limit,
169                    window: Duration::from_secs((*window_secs).max(1)),
170                },
171            )],
172            ThrottlePolicy::Named(name) => self
173                .inner
174                .policies
175                .get(*name)
176                .map(|limit| vec![((*name).to_owned(), *limit)])
177                .unwrap_or_default(),
178            ThrottlePolicy::Multiple(names) => names
179                .iter()
180                .filter_map(|name| {
181                    self.inner
182                        .policies
183                        .get(*name)
184                        .map(|limit| ((*name).to_owned(), *limit))
185                })
186                .collect(),
187        }
188    }
189
190    /// Counts a hit against one limit and decides whether it is allowed.
191    async fn decide_one(&self, scope: &str, disc: &str, limit: Limit, key: &str) -> Decision {
192        let window_secs = limit.window.as_secs().max(1);
193        let now = unix_secs();
194        let bucket = now / window_secs;
195        let elapsed = now % window_secs;
196        let cap = u64::from(limit.limit);
197
198        if self.inner.sliding {
199            // Sliding window: this window's count plus the previous window's count
200            // weighted by how much of it still overlaps now. Keep buckets for two
201            // windows so the previous one is still readable.
202            let current_key = format!("throttle:{scope}:{disc}:{key}:{bucket}");
203            let previous_key = format!("throttle:{scope}:{disc}:{key}:{}", bucket.wrapping_sub(1));
204            let current = self
205                .inner
206                .store
207                .incr(current_key, limit.window * 2)
208                .await
209                .unwrap_or(0);
210            let previous = self.inner.store.count(previous_key).await.unwrap_or(0);
211            let weight = (window_secs - elapsed) as f64 / window_secs as f64;
212            let estimate = current as f64 + previous as f64 * weight;
213            if estimate > cap as f64 {
214                return Decision::Deny {
215                    retry_after: window_secs - elapsed,
216                };
217            }
218        } else {
219            let storage_key = format!("throttle:{scope}:{disc}:{key}:{bucket}");
220            let count = self
221                .inner
222                .store
223                .incr(storage_key, limit.window)
224                .await
225                .unwrap_or(0);
226            if count > cap {
227                return Decision::Deny {
228                    retry_after: window_secs - elapsed,
229                };
230            }
231        }
232        Decision::Allow
233    }
234
235    /// Enforces a policy, returning `Err(429)` when a limit is exceeded.
236    ///
237    /// `key` is the precomputed tracker; `None` falls back to the client IP.
238    pub async fn check(
239        &self,
240        ctx: &RequestContext,
241        policy: &ThrottlePolicy,
242        key: Option<String>,
243    ) -> Result<()> {
244        let scope = ctx.uri().path().to_owned();
245        match self.enforce(ctx, &scope, policy, key).await {
246            Decision::Allow => Ok(()),
247            Decision::Deny { .. } => Err(too_many()),
248        }
249    }
250
251    /// Shared resolution: resolve the limits, compute the key, check each.
252    async fn enforce(
253        &self,
254        ctx: &RequestContext,
255        scope: &str,
256        policy: &ThrottlePolicy,
257        key: Option<String>,
258    ) -> Decision {
259        let limits = self.resolve(policy);
260        if limits.is_empty() {
261            return Decision::Allow;
262        }
263        let key = match key {
264            Some(key) => key,
265            None => match ByIp::throttle_key(ctx).await {
266                Ok(key) => key,
267                Err(_) => return Decision::Allow,
268            },
269        };
270        // Every limit must allow; the first to deny wins.
271        for (disc, limit) in &limits {
272            if let Decision::Deny { retry_after } = self.decide_one(scope, disc, *limit, &key).await
273            {
274                return Decision::Deny { retry_after };
275            }
276        }
277        Decision::Allow
278    }
279}
280
281impl FromRequest for Throttler {
282    fn from_request(ctx: &RequestContext) -> impl Future<Output = Result<Self>> + Send {
283        let resolved = ctx
284            .state()
285            .get::<Throttler>()
286            .map(|throttler| (*throttler).clone())
287            .ok_or_else(|| {
288                Error::internal("throttling is not configured; call `App::throttle(...)`")
289            });
290        async move { resolved }
291    }
292}
293
294/// Generated-code entry point: enforce a route's policy, returning a `429`
295/// response when a limit is exceeded (or `None` to proceed).
296///
297/// A no-op (returns `None`) when no [`Throttler`] is configured, so the check the
298/// route macro always emits costs only a state lookup in apps that do not throttle.
299#[doc(hidden)]
300pub async fn check_request(
301    ctx: &RequestContext,
302    scope: &'static str,
303    policy: &ThrottlePolicy,
304    key: Option<String>,
305) -> Option<Response> {
306    let throttler = ctx.state().get::<Throttler>()?;
307    match throttler.enforce(ctx, scope, policy, key).await {
308        Decision::Allow => None,
309        Decision::Deny { retry_after } => Some(deny_response(retry_after)),
310    }
311}
312
313/// Builds the `429 Too Many Requests` response with a `Retry-After` header.
314fn deny_response(retry_after: u64) -> Response {
315    let mut response = too_many().into_response();
316    if let Ok(value) = HeaderValue::from_str(&retry_after.to_string()) {
317        response.headers_mut().insert(RETRY_AFTER, value);
318    }
319    response
320}
321
322/// The standard rate-limit error.
323fn too_many() -> Error {
324    Error::too_many_requests("rate limit exceeded").with_code("RATE_LIMITED")
325}
326
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A clock reading earlier than the epoch yields `0` instead of panicking.
fn unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}