1use std::collections::HashMap;
4use std::future::Future;
5use std::sync::Arc;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
8use http::header::RETRY_AFTER;
9use http::HeaderValue;
10
11use crate::error::{Error, Result};
12use crate::extract::{FromRequest, RequestContext};
13use crate::response::{IntoResponse, Response};
14
15use super::key::{ByIp, ThrottleKey};
16use super::store::{MemoryThrottleStore, ThrottleStore};
17
/// Per-route selection of which rate limits apply to a request.
///
/// Policies are compared/hashed structurally, so they can be used in assertions
/// or as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ThrottlePolicy {
    /// Apply the throttler's default policy. If no default was configured (or the
    /// configured default name was never registered), no limit applies.
    Inherit,
    /// Apply no rate limit at all.
    Skip,
    /// Apply a single policy registered by name with `Throttle::policy`.
    /// An unregistered name imposes no limit.
    Named(&'static str),
    /// Apply an ad-hoc limit of `limit` requests per `window_secs` seconds.
    /// A zero-second window is clamped to one second.
    Inline { limit: u32, window_secs: u64 },
    /// Apply every listed registered policy; all of them must allow the request.
    /// Unregistered names are skipped.
    Multiple(&'static [&'static str]),
}
35
/// A concrete rate limit: at most `limit` hits per `window`.
#[derive(Clone, Copy)]
struct Limit {
    /// Maximum number of hits permitted within one window.
    limit: u32,
    /// Length of the counting window.
    window: Duration,
}
42
43pub struct Throttle {
49 policies: HashMap<String, Limit>,
50 default: Option<String>,
51 store: Arc<dyn ThrottleStore>,
52 sliding: bool,
53}
54
55impl Throttle {
56 pub fn new() -> Self {
58 Self {
59 policies: HashMap::new(),
60 default: None,
61 store: Arc::new(MemoryThrottleStore::new()),
62 sliding: false,
63 }
64 }
65
66 pub fn policy(mut self, name: &str, limit: u32, window_secs: u64) -> Self {
68 self.policies.insert(
69 name.to_owned(),
70 Limit {
71 limit,
72 window: Duration::from_secs(window_secs.max(1)),
73 },
74 );
75 self
76 }
77
78 pub fn default(mut self, name: &str) -> Self {
81 self.default = Some(name.to_owned());
82 self
83 }
84
85 pub fn store(mut self, store: impl ThrottleStore) -> Self {
87 self.store = Arc::new(store);
88 self
89 }
90
91 pub fn sliding(mut self) -> Self {
97 self.sliding = true;
98 self
99 }
100
101 #[cfg(feature = "redis")]
103 pub fn redis(mut self, redis: &crate::Redis) -> Self {
104 self.store = Arc::new(super::redis::RedisThrottleStore::new(redis));
105 self
106 }
107}
108
109impl Default for Throttle {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115#[derive(Clone)]
117pub struct Throttler {
118 inner: Arc<Inner>,
119}
120
121struct Inner {
122 policies: HashMap<String, Limit>,
123 default: Option<(String, Limit)>,
124 store: Arc<dyn ThrottleStore>,
125 sliding: bool,
126}
127
/// Outcome of evaluating the applicable limits for one request.
enum Decision {
    /// Every applicable limit still has capacity.
    Allow,
    /// Some limit is exhausted; `retry_after` is the number of seconds until its
    /// current window rolls over.
    Deny { retry_after: u64 },
}
133
134impl Throttler {
135 pub fn new(config: Throttle) -> Self {
137 let default = config.default.as_ref().and_then(|name| {
138 config
139 .policies
140 .get(name)
141 .map(|limit| (name.clone(), *limit))
142 });
143 Self {
144 inner: Arc::new(Inner {
145 policies: config.policies,
146 default,
147 store: config.store,
148 sliding: config.sliding,
149 }),
150 }
151 }
152
153 fn resolve(&self, policy: &ThrottlePolicy) -> Vec<(String, Limit)> {
157 match policy {
158 ThrottlePolicy::Skip => Vec::new(),
159 ThrottlePolicy::Inherit => self
160 .inner
161 .default
162 .as_ref()
163 .map(|(name, limit)| vec![(name.clone(), *limit)])
164 .unwrap_or_default(),
165 ThrottlePolicy::Inline { limit, window_secs } => vec![(
166 format!("inline:{limit}:{window_secs}"),
167 Limit {
168 limit: *limit,
169 window: Duration::from_secs((*window_secs).max(1)),
170 },
171 )],
172 ThrottlePolicy::Named(name) => self
173 .inner
174 .policies
175 .get(*name)
176 .map(|limit| vec![((*name).to_owned(), *limit)])
177 .unwrap_or_default(),
178 ThrottlePolicy::Multiple(names) => names
179 .iter()
180 .filter_map(|name| {
181 self.inner
182 .policies
183 .get(*name)
184 .map(|limit| ((*name).to_owned(), *limit))
185 })
186 .collect(),
187 }
188 }
189
190 async fn decide_one(&self, scope: &str, disc: &str, limit: Limit, key: &str) -> Decision {
192 let window_secs = limit.window.as_secs().max(1);
193 let now = unix_secs();
194 let bucket = now / window_secs;
195 let elapsed = now % window_secs;
196 let cap = u64::from(limit.limit);
197
198 if self.inner.sliding {
199 let current_key = format!("throttle:{scope}:{disc}:{key}:{bucket}");
203 let previous_key = format!("throttle:{scope}:{disc}:{key}:{}", bucket.wrapping_sub(1));
204 let current = self
205 .inner
206 .store
207 .incr(current_key, limit.window * 2)
208 .await
209 .unwrap_or(0);
210 let previous = self.inner.store.count(previous_key).await.unwrap_or(0);
211 let weight = (window_secs - elapsed) as f64 / window_secs as f64;
212 let estimate = current as f64 + previous as f64 * weight;
213 if estimate > cap as f64 {
214 return Decision::Deny {
215 retry_after: window_secs - elapsed,
216 };
217 }
218 } else {
219 let storage_key = format!("throttle:{scope}:{disc}:{key}:{bucket}");
220 let count = self
221 .inner
222 .store
223 .incr(storage_key, limit.window)
224 .await
225 .unwrap_or(0);
226 if count > cap {
227 return Decision::Deny {
228 retry_after: window_secs - elapsed,
229 };
230 }
231 }
232 Decision::Allow
233 }
234
235 pub async fn check(
239 &self,
240 ctx: &RequestContext,
241 policy: &ThrottlePolicy,
242 key: Option<String>,
243 ) -> Result<()> {
244 let scope = ctx.uri().path().to_owned();
245 match self.enforce(ctx, &scope, policy, key).await {
246 Decision::Allow => Ok(()),
247 Decision::Deny { .. } => Err(too_many()),
248 }
249 }
250
251 async fn enforce(
253 &self,
254 ctx: &RequestContext,
255 scope: &str,
256 policy: &ThrottlePolicy,
257 key: Option<String>,
258 ) -> Decision {
259 let limits = self.resolve(policy);
260 if limits.is_empty() {
261 return Decision::Allow;
262 }
263 let key = match key {
264 Some(key) => key,
265 None => match ByIp::throttle_key(ctx).await {
266 Ok(key) => key,
267 Err(_) => return Decision::Allow,
268 },
269 };
270 for (disc, limit) in &limits {
272 if let Decision::Deny { retry_after } = self.decide_one(scope, disc, *limit, &key).await
273 {
274 return Decision::Deny { retry_after };
275 }
276 }
277 Decision::Allow
278 }
279}
280
281impl FromRequest for Throttler {
282 fn from_request(ctx: &RequestContext) -> impl Future<Output = Result<Self>> + Send {
283 let resolved = ctx
284 .state()
285 .get::<Throttler>()
286 .map(|throttler| (*throttler).clone())
287 .ok_or_else(|| {
288 Error::internal("throttling is not configured; call `App::throttle(...)`")
289 });
290 async move { resolved }
291 }
292}
293
294#[doc(hidden)]
300pub async fn check_request(
301 ctx: &RequestContext,
302 scope: &'static str,
303 policy: &ThrottlePolicy,
304 key: Option<String>,
305) -> Option<Response> {
306 let throttler = ctx.state().get::<Throttler>()?;
307 match throttler.enforce(ctx, scope, policy, key).await {
308 Decision::Allow => None,
309 Decision::Deny { retry_after } => Some(deny_response(retry_after)),
310 }
311}
312
313fn deny_response(retry_after: u64) -> Response {
315 let mut response = too_many().into_response();
316 if let Ok(value) = HeaderValue::from_str(&retry_after.to_string()) {
317 response.headers_mut().insert(RETRY_AFTER, value);
318 }
319 response
320}
321
322fn too_many() -> Error {
324 Error::too_many_requests("rate limit exceeded").with_code("RATE_LIMITED")
325}
326
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}