thruster_rate_limit/
lib.rs

1#![allow(clippy::needless_return)]
2
3use thruster::{
4    context::typed_hyper_context::TypedHyperContext, errors::ThrusterError, middleware_fn, Context,
5    ContextState, MiddlewareNext, MiddlewareResult,
6};
7
8pub mod stores;
9use stores::Store;
10
11mod rate_limiter;
12pub use rate_limiter::{Options, RateLimiter};
13
14mod utils;
15
16pub trait Configuration<S: Send> {
17    fn should_limit(&self, _context: &TypedHyperContext<S>) -> bool {
18        return true;
19    }
20    fn get_key(&self, context: &TypedHyperContext<S>) -> String {
21        if let Some(request) = context.hyper_request.as_ref() {
22            if let Some(ip) = request.ip {
23                return ip.to_string();
24            }
25        }
26
27        return "".to_string();
28    }
29}
30
31#[middleware_fn]
32pub async fn rate_limit_middleware<
33    T: Send + Sync + ContextState<RateLimiter<S>> + ContextState<Box<C>>,
34    S: 'static + Store + Send + Sync + Clone,
35    C: 'static + Configuration<T> + Sync,
36>(
37    mut context: TypedHyperContext<T>,
38    next: MiddlewareNext<TypedHyperContext<T>>,
39) -> MiddlewareResult<TypedHyperContext<T>> {
40    #[allow(clippy::borrowed_box)]
41    let configuration: &Box<_> = context.extra.get();
42
43    if !configuration.should_limit(&context) {
44        return next(context).await;
45    }
46
47    let rate_limiter: &RateLimiter<S> = context.extra.get();
48    let RateLimiter { mut store, .. } = rate_limiter.clone();
49
50    let (path, options) = match rate_limiter.matches_route(context.route()) {
51        Some(x) => x,
52        None => ("".to_string(), rate_limiter.options.clone()),
53    };
54
55    let key = format!("rate-limit:{}:{}", configuration.get_key(&context), path);
56
57    let current_count: Option<usize> = store.get(&key).await.unwrap();
58
59    let current_count = current_count.unwrap_or(0);
60    let new_count = current_count + 1;
61
62    if new_count > options.max {
63        context.status(429);
64        return Err(ThrusterError {
65            cause: None,
66            context,
67            message: format!("Rate limit exceeded, please wait {} seconds", options.per_s),
68        });
69    }
70
71    store.set(&key, new_count, options.per_s).await.unwrap();
72
73    return next(context).await;
74}