1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use chrono::prelude::*;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::result;
use std::sync::mpsc;
use std::time::Duration;
/// Sending half of a `std::sync::mpsc` channel used to deliver a rate-limit verdict.
///
/// The payload is either `Ok(RatelimitResponse)` or an `Err(String)`.
/// `Ratelimiter::pass` in this file only ever sends the `Ok` variant.
pub type RatelimitSender = mpsc::Sender<result::Result<RatelimitResponse, String>>;
/// Rate limiter that tracks an independent `Bucket` for every key it has seen.
pub struct Ratelimiter {
// One bucket per key; an entry is inserted the first time `pass` sees a key.
buckets: HashMap<String, Bucket>,
}
impl Default for Ratelimiter {
fn default() -> Self {
Self::new()
}
}
impl Ratelimiter {
    /// Creates a rate limiter that has no buckets yet.
    pub fn new() -> Ratelimiter {
        Ratelimiter {
            buckets: HashMap::new(),
        }
    }

    /// Evaluates one request for `key` and reports the verdict on `tx`.
    ///
    /// * `tx` - channel on which the verdict is sent, always wrapped in `Ok`.
    /// * `key` - identifies the bucket; a new bucket is created on first use.
    /// * `passes` - maximum number of passes allowed inside the window.
    /// * `time` - length of the window in milliseconds.
    ///
    /// If the receiving end of `tx` has already been dropped the verdict is
    /// discarded instead of panicking. The pass, if granted, stays recorded
    /// in the bucket.
    pub fn pass(&mut self, tx: RatelimitSender, key: &str, passes: u32, time: u32) {
        let limit = passes as usize;

        // A single lookup on the hot path; the owned key is only allocated
        // when the bucket does not exist yet.
        let reply = match self.buckets.get_mut(key) {
            Some(bucket) => bucket.pass(limit, time),
            None => {
                let mut bucket = Bucket::new();
                let reply = bucket.pass(limit, time);
                self.buckets.insert(key.to_string(), bucket);
                reply
            }
        };

        // `send` fails only when the receiver is gone. A requester that gave
        // up waiting must not be able to take the whole limiter down, so the
        // error is deliberately ignored.
        let _ = tx.send(Ok(reply));
    }
}
/// Record of the passes granted for a single key.
struct Bucket {
// UTC timestamps of granted passes; `Bucket::pass` drops the ones that have
// fallen out of the window before each check.
passes: Vec<DateTime<Utc>>,
}
impl Bucket {
    /// Creates a bucket with no recorded passes.
    pub fn new() -> Bucket {
        Bucket { passes: Vec::new() }
    }

    /// Decides whether another pass fits into the current window.
    ///
    /// * `passes` - maximum number of passes allowed inside the window.
    /// * `time` - length of the window in milliseconds.
    ///
    /// Timestamps older than the window are discarded first. If fewer than
    /// `passes` remain, the current time is recorded and `Pass` is returned.
    /// Otherwise `Retry` carries the time left until the oldest remaining
    /// pass leaves the window, or 100 ms when nothing is recorded at all
    /// (which only happens when `passes` is zero).
    pub fn pass(&mut self, passes: usize, time: u32) -> RatelimitResponse {
        let checked_at = Utc::now();
        let window = chrono::Duration::milliseconds(i64::from(time));
        let cutoff = checked_at - window;

        // Forget everything that happened before the start of the window.
        self.passes.retain(|stamp| *stamp >= cutoff);

        if self.passes.len() < passes {
            self.passes.push(Utc::now());
            return RatelimitResponse::Pass;
        }

        match self.passes.iter().min() {
            Some(oldest) => {
                // Remaining lifetime of the oldest pass inside the window.
                let wait = window - (checked_at - *oldest);
                RatelimitResponse::Retry(wait.to_std().unwrap())
            }
            None => RatelimitResponse::Retry(Duration::from_millis(100)),
        }
    }
}
/// Verdict produced by a rate-limit check.
#[derive(Debug, Serialize, Deserialize)]
pub enum RatelimitResponse {
/// The limit is currently reached; the duration is how long to wait before
/// trying again.
Retry(Duration),
/// The request was allowed and has been recorded.
Pass,
}