rwf/controller/middleware/rate_limiter.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
//! Limit how many requests our clients can perform per unit of time.
//!
//! Clients that exceed those limits will have their requests rejected with HTTP `429 Too Many Requests`.
//! The rate limiting algorithm is nothing fancy: it counts each client's requests and resets the count
//! at the end of each configured time period.
//!
//! Clients are bucketed per IP address. The rate limiter supports proxies: if the `X-Forwarded-For` header is present,
//! the IP address it contains is used instead of the connection's peer address. Responses carry the `X-Rwf-Request-Rate`
//! header, set to the client's average request rate (requests per second) over its most recently completed period
//! (0 until the first period completes), which can help clients self-throttle their request rate.
use std::collections::HashMap;
use std::net::IpAddr;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use super::{
super::{Error, Request, Response},
Middleware, Outcome,
};
use async_trait::async_trait;
/// Shared rate-limiter state: one request counter per client IP address.
#[derive(Debug, Default)]
struct State {
    buckets: HashMap<IpAddr, Counter>,
}
/// Per-client request tally for the current counting period.
#[derive(Debug)]
struct Counter {
    /// Requests counted since `last_reset`.
    counter: u64,
    /// Requests per second measured over the previous period (0 until one completes).
    rate: f32,
    /// Start of the current counting period.
    last_reset: Instant,
}

impl Counter {
    /// Start an empty tally whose period begins at `last_reset`.
    fn new(last_reset: Instant) -> Self {
        Self {
            last_reset,
            rate: 0.0,
            counter: 0,
        }
    }
}
/// A request limit paired with the unit of time it applies to.
enum Frequency {
    Minute(u64),
    Second(u64),
    Hour(u64),
    Day(u64),
}

impl Frequency {
    /// Maximum number of requests allowed per unit of time.
    pub fn limit(&self) -> u64 {
        // Every variant carries the limit, so a single irrefutable pattern extracts it.
        let (Self::Minute(max) | Self::Second(max) | Self::Hour(max) | Self::Day(max)) = self;
        *max
    }
}
/// Simple fixed-window rate limiter, bucketed by client IP address.
///
/// Create one with [`RateLimiter::per_second`], [`RateLimiter::per_minute`],
/// [`RateLimiter::per_hour`] or [`RateLimiter::per_day`].
pub struct RateLimiter {
    frequency: Frequency,
    state: Mutex<State>,
}

impl RateLimiter {
    /// Create a rate limiter enforcing the given request frequency, with no clients tracked yet.
    fn new(frequency: Frequency) -> Self {
        Self {
            frequency,
            state: Mutex::new(State::default()),
        }
    }

    /// Create a rate limiter allowing at most `limit` requests per second from each client.
    pub fn per_second(limit: u64) -> Self {
        Self::new(Frequency::Second(limit))
    }

    /// Create a rate limiter allowing at most `limit` requests per minute from each client.
    pub fn per_minute(limit: u64) -> Self {
        Self::new(Frequency::Minute(limit))
    }

    /// Create a rate limiter allowing at most `limit` requests per hour from each client.
    ///
    /// There is no advance warning for clients that reach this limit quickly: if they spend
    /// all their requests in the first minute of the hour, they will be blocked from sending
    /// any more for the remainder of the hour.
    pub fn per_hour(limit: u64) -> Self {
        Self::new(Frequency::Hour(limit))
    }

    /// Create a rate limiter allowing at most `limit` requests per day from each client.
    ///
    /// There is no advance warning for clients that reach this limit quickly: if they spend
    /// all their requests in the first hour of the day, they will be blocked from sending
    /// any more for the remainder of the day.
    pub fn per_day(limit: u64) -> Self {
        Self::new(Frequency::Day(limit))
    }
}
#[async_trait]
impl Middleware for RateLimiter {
async fn handle_request(&self, request: Request) -> Result<Outcome, Error> {
let peer = match request
.headers()
.get("x-forwarded-for")
.map(|s| crate::peer_addr(s))
{
Some(Some(peer)) => peer,
_ => request.peer().clone(),
};
// Get current time before locking mutex.
// You'd be surprised how slow this function can be.
let now = Instant::now();
let reset_duration = match self.frequency {
Frequency::Second(limit) => Duration::from_millis(1000 * limit),
Frequency::Minute(limit) => Duration::from_millis(1000 * 60 * limit),
Frequency::Hour(limit) => Duration::from_millis(1000 * 3600 * limit),
Frequency::Day(limit) => Duration::from_millis(1000 * 3600 * 24 * limit),
};
let too_many = {
let mut guard = self.state.lock();
let state = guard
.buckets
.entry(peer.ip())
.or_insert_with(|| Counter::new(now));
let duration = now.duration_since(state.last_reset);
state.counter += 1;
if duration >= reset_duration {
state.rate = state.counter as f32 / duration.as_secs_f32();
state.counter = 1;
state.last_reset = now;
}
state.counter > self.frequency.limit()
};
if too_many {
Ok(Outcome::Stop(request, Response::too_many()))
} else {
Ok(Outcome::Forward(request))
}
}
async fn handle_response(
&self,
request: &Request,
response: Response,
) -> Result<Response, Error> {
if let Some(rate) = self
.state
.lock()
.buckets
.get(&request.peer().ip())
.map(|c| c.rate)
{
Ok(response.header("x-rwf-request-rate", rate.to_string()))
} else {
Ok(response)
}
}
}