thruster_rate_limit/
lib.rs1#![allow(clippy::needless_return)]
2
3use thruster::{
4 context::typed_hyper_context::TypedHyperContext, errors::ThrusterError, middleware_fn, Context,
5 ContextState, MiddlewareNext, MiddlewareResult,
6};
7
8pub mod stores;
9use stores::Store;
10
11mod rate_limiter;
12pub use rate_limiter::{Options, RateLimiter};
13
14mod utils;
15
16pub trait Configuration<S: Send> {
17 fn should_limit(&self, _context: &TypedHyperContext<S>) -> bool {
18 return true;
19 }
20 fn get_key(&self, context: &TypedHyperContext<S>) -> String {
21 if let Some(request) = context.hyper_request.as_ref() {
22 if let Some(ip) = request.ip {
23 return ip.to_string();
24 }
25 }
26
27 return "".to_string();
28 }
29}
30
31#[middleware_fn]
32pub async fn rate_limit_middleware<
33 T: Send + Sync + ContextState<RateLimiter<S>> + ContextState<Box<C>>,
34 S: 'static + Store + Send + Sync + Clone,
35 C: 'static + Configuration<T> + Sync,
36>(
37 mut context: TypedHyperContext<T>,
38 next: MiddlewareNext<TypedHyperContext<T>>,
39) -> MiddlewareResult<TypedHyperContext<T>> {
40 #[allow(clippy::borrowed_box)]
41 let configuration: &Box<_> = context.extra.get();
42
43 if !configuration.should_limit(&context) {
44 return next(context).await;
45 }
46
47 let rate_limiter: &RateLimiter<S> = context.extra.get();
48 let RateLimiter { mut store, .. } = rate_limiter.clone();
49
50 let (path, options) = match rate_limiter.matches_route(context.route()) {
51 Some(x) => x,
52 None => ("".to_string(), rate_limiter.options.clone()),
53 };
54
55 let key = format!("rate-limit:{}:{}", configuration.get_key(&context), path);
56
57 let current_count: Option<usize> = store.get(&key).await.unwrap();
58
59 let current_count = current_count.unwrap_or(0);
60 let new_count = current_count + 1;
61
62 if new_count > options.max {
63 context.status(429);
64 return Err(ThrusterError {
65 cause: None,
66 context,
67 message: format!("Rate limit exceeded, please wait {} seconds", options.per_s),
68 });
69 }
70
71 store.set(&key, new_count, options.per_s).await.unwrap();
72
73 return next(context).await;
74}