1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
15#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
16#![cfg_attr(docsrs, feature(doc_cfg))]
17
18use std::borrow::Borrow;
19use std::error::Error as StdError;
20use std::fmt::{self, Debug, Formatter};
21use std::hash::Hash;
22
23use salvo_core::conn::SocketAddr;
24use salvo_core::handler::{Skipper, none_skipper};
25use salvo_core::http::{HeaderValue, Request, Response, StatusCode, StatusError};
26use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
27
28mod quota;
29pub use quota::{BasicQuota, CelledQuota, QuotaGetter};
30#[macro_use]
31mod cfg;
32
33cfg_feature! {
34 #![feature = "moka-store"]
35
36 mod moka_store;
37 pub use moka_store::MokaStore;
38}
39
40cfg_feature! {
41 #![feature = "fixed-guard"]
42
43 mod fixed_guard;
44 pub use fixed_guard::FixedGuard;
45}
46
47cfg_feature! {
48 #![feature = "sliding-guard"]
49
50 mod sliding_guard;
51 pub use sliding_guard::SlidingGuard;
52}
53
54pub trait RateIssuer: Send + Sync + 'static {
56 type Key: Hash + Eq + Send + Sync + 'static;
58 fn issue(
60 &self,
61 req: &mut Request,
62 depot: &Depot,
63 ) -> impl Future<Output = Option<Self::Key>> + Send;
64}
65impl<F, K> RateIssuer for F
66where
67 F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
68 K: Hash + Eq + Send + Sync + 'static,
69{
70 type Key = K;
71 async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
72 (self)(req, depot)
73 }
74}
75
76#[derive(Debug)]
78pub struct RemoteIpIssuer;
79impl RateIssuer for RemoteIpIssuer {
80 type Key = String;
81 async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
82 match req.remote_addr() {
83 SocketAddr::IPv4(addr) => Some(addr.ip().to_string()),
84 SocketAddr::IPv6(addr) => Some(addr.ip().to_string()),
85 _ => None,
86 }
87 }
88}
89
90pub trait RateGuard: Clone + Send + Sync + 'static {
92 type Quota: Clone + Send + Sync + 'static;
94 fn verify(&mut self, quota: &Self::Quota) -> impl Future<Output = bool> + Send;
96
97 fn remaining(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;
99
100 fn reset(&self, quota: &Self::Quota) -> impl Future<Output = i64> + Send;
102
103 fn limit(&self, quota: &Self::Quota) -> impl Future<Output = usize> + Send;
105}
106
107pub trait RateStore: Send + Sync + 'static {
109 type Error: StdError;
111 type Key: Hash + Eq + Send + Clone + 'static;
113 type Guard;
115 fn load_guard<Q>(
117 &self,
118 key: &Q,
119 refer: &Self::Guard,
120 ) -> impl Future<Output = Result<Self::Guard, Self::Error>> + Send
121 where
122 Self::Key: Borrow<Q>,
123 Q: Hash + Eq + Sync;
124 fn save_guard(
126 &self,
127 key: Self::Key,
128 guard: Self::Guard,
129 ) -> impl Future<Output = Result<(), Self::Error>> + Send;
130}
131
132pub struct RateLimiter<G, S, I, Q> {
134 guard: G,
135 store: S,
136 issuer: I,
137 quota_getter: Q,
138 add_headers: bool,
139 skipper: Box<dyn Skipper>,
140}
141impl<G, S, I, Q> Debug for RateLimiter<G, S, I, Q>
142where
143 G: Debug,
144 S: Debug,
145 I: Debug,
146 Q: Debug,
147{
148 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
149 f.debug_struct("RateLimiter")
150 .field("guard", &self.guard)
151 .field("store", &self.store)
152 .field("issuer", &self.issuer)
153 .field("quota_getter", &self.quota_getter)
154 .field("add_headers", &self.add_headers)
155 .finish()
156 }
157}
158
159impl<G: RateGuard, S: RateStore, I: RateIssuer, P: QuotaGetter<I::Key>> RateLimiter<G, S, I, P> {
160 #[inline]
162 #[must_use]
163 pub fn new(guard: G, store: S, issuer: I, quota_getter: P) -> Self {
164 Self {
165 guard,
166 store,
167 issuer,
168 quota_getter,
169 add_headers: false,
170 skipper: Box::new(none_skipper),
171 }
172 }
173
174 #[inline]
176 #[must_use]
177 pub fn with_skipper(mut self, skipper: impl Skipper) -> Self {
178 self.skipper = Box::new(skipper);
179 self
180 }
181
182 #[inline]
185 #[must_use]
186 pub fn add_headers(mut self, add_headers: bool) -> Self {
187 self.add_headers = add_headers;
188 self
189 }
190}
191
192#[async_trait]
193impl<G, S, I, P> Handler for RateLimiter<G, S, I, P>
194where
195 G: RateGuard<Quota = P::Quota>,
196 S: RateStore<Key = I::Key, Guard = G>,
197 P: QuotaGetter<I::Key>,
198 I: RateIssuer,
199{
200 async fn handle(
201 &self,
202 req: &mut Request,
203 depot: &mut Depot,
204 res: &mut Response,
205 ctrl: &mut FlowCtrl,
206 ) {
207 if self.skipper.skipped(req, depot) {
208 return;
209 }
210 let Some(key) = self.issuer.issue(req, depot).await else {
211 res.render(StatusError::bad_request().brief("Invalid identifier."));
212 ctrl.skip_rest();
213 return;
214 };
215 let quota = match self.quota_getter.get(&key).await {
216 Ok(quota) => quota,
217 Err(e) => {
218 tracing::error!(error = ?e, "RateLimiter error: {}", e);
219 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
220 ctrl.skip_rest();
221 return;
222 }
223 };
224 let mut guard = match self.store.load_guard(&key, &self.guard).await {
225 Ok(guard) => guard,
226 Err(e) => {
227 tracing::error!(error = ?e, "RateLimiter error: {}", e);
228 res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
229 ctrl.skip_rest();
230 return;
231 }
232 };
233 let verified = guard.verify("a).await;
234
235 if self.add_headers {
236 res.headers_mut().insert(
237 "X-RateLimit-Limit",
238 HeaderValue::from_str(&guard.limit("a).await.to_string())
239 .expect("Invalid header value"),
240 );
241 res.headers_mut().insert(
242 "X-RateLimit-Remaining",
243 HeaderValue::from_str(&(guard.remaining("a).await).to_string())
244 .expect("Invalid header value"),
245 );
246 res.headers_mut().insert(
247 "X-RateLimit-Reset",
248 HeaderValue::from_str(&guard.reset("a).await.to_string())
249 .expect("Invalid header value"),
250 );
251 }
252 if !verified {
253 res.status_code(StatusCode::TOO_MANY_REQUESTS);
254 ctrl.skip_rest();
255 }
256 if let Err(e) = self.store.save_guard(key, guard).await {
257 tracing::error!(error = ?e, "RateLimiter save guard failed");
258 }
259 }
260}
261
// The tests need the Moka store plus at least one guard implementation. Gating on those
// features keeps `cargo test` compiling under any feature selection.
#[cfg(all(
    test,
    feature = "moka-store",
    any(feature = "fixed-guard", feature = "sliding-guard")
))]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use salvo_core::Error;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    const USER1_URL: &str = "http://127.0.0.1:8698/limited?user=user1";
    const USER2_URL: &str = "http://127.0.0.1:8698/limited?user=user2";

    /// Keys requests by their `user` query parameter.
    struct UserIssuer;
    impl RateIssuer for UserIssuer {
        type Key = String;
        async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
            req.query::<Self::Key>("user")
        }
    }

    #[handler]
    async fn limited() -> &'static str {
        "Limited page"
    }

    /// Mounts `limiter` as a hoop in front of the `limited` handler at `/limited`.
    fn limited_service(limiter: impl Handler) -> Service {
        let router = Router::new().push(Router::with_path("limited").hoop(limiter).get(limited));
        Service::new(router)
    }

    /// Asserts that `url` is admitted and reaches the `limited` handler.
    async fn assert_allowed(service: &Service, url: &'static str) {
        let mut response = TestClient::get(url).send(service).await;
        assert_eq!(
            response.status_code,
            Some(StatusCode::OK),
            "expected {url} to be allowed"
        );
        assert_eq!(response.take_string().await.unwrap(), "Limited page");
    }

    /// Asserts that `url` is rejected with `429 Too Many Requests`.
    async fn assert_limited(service: &Service, url: &'static str) {
        let response = TestClient::get(url).send(service).await;
        assert_eq!(
            response.status_code,
            Some(StatusCode::TOO_MANY_REQUESTS),
            "expected {url} to be rate limited"
        );
    }

    /// Shared scenario for both guard kinds.
    ///
    /// `user1` is admitted, limited, then admitted again after 1s. `user2` is admitted,
    /// stays limited after 1s, and is admitted again after a further 6s.
    async fn check_dynamic_quota(service: &Service) {
        assert_allowed(service, USER1_URL).await;
        assert_limited(service, USER1_URL).await;
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        assert_allowed(service, USER1_URL).await;

        assert_allowed(service, USER2_URL).await;
        assert_limited(service, USER2_URL).await;
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        assert_limited(service, USER2_URL).await;
        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
        assert_allowed(service, USER2_URL).await;
    }

    #[cfg(feature = "fixed-guard")]
    #[tokio::test]
    async fn test_fixed_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, BasicQuota>> = LazyLock::new(|| {
            let mut map = HashMap::new();
            map.insert("user1".into(), BasicQuota::per_second(1));
            map.insert("user2".into(), BasicQuota::set_seconds(1, 5));
            map
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = BasicQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }

        let limiter = RateLimiter::new(
            FixedGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let service = limited_service(limiter);
        check_dynamic_quota(&service).await;
    }

    #[cfg(feature = "sliding-guard")]
    #[tokio::test]
    async fn test_sliding_dynamic_quota() {
        static USER_QUOTAS: LazyLock<HashMap<String, CelledQuota>> = LazyLock::new(|| {
            let mut map = HashMap::new();
            map.insert("user1".into(), CelledQuota::per_second(1, 1));
            map.insert("user2".into(), CelledQuota::set_seconds(1, 1, 5));
            map
        });

        struct CustomQuotaGetter;
        impl QuotaGetter<String> for CustomQuotaGetter {
            type Quota = CelledQuota;
            type Error = Error;

            async fn get<Q>(&self, key: &Q) -> Result<Self::Quota, Self::Error>
            where
                String: Borrow<Q>,
                Q: Hash + Eq + Sync,
            {
                USER_QUOTAS
                    .get(key)
                    .cloned()
                    .ok_or_else(|| Error::other("user not found"))
            }
        }

        let limiter = RateLimiter::new(
            SlidingGuard::default(),
            MokaStore::default(),
            UserIssuer,
            CustomQuotaGetter,
        );
        let service = limited_service(limiter);
        check_dynamic_quota(&service).await;
    }
}