1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use chrono::prelude::*;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::result;
use std::sync::mpsc;
use std::time::Duration;

/// Sending half of the `mpsc` channel on which `Ratelimiter::pass` reports its result.
///
/// Within this file only the `Ok` variant is ever sent; nothing here produces the
/// `Err(String)` side.
pub type RatelimitSender = mpsc::Sender<result::Result<RatelimitResponse, String>>;

/// Keyed rate limiter that keeps one `Bucket` of recent passes per key.
pub struct Ratelimiter {
    /// Per-key state. An entry is added the first time a key is passed; the code in
    /// this file never removes entries.
    buckets: HashMap<String, Bucket>,
}

impl Default for Ratelimiter {
    fn default() -> Self {
        Self::new()
    }
}

impl Ratelimiter {
    /// Creates a rate limiter that is not tracking any keys yet.
    pub fn new() -> Ratelimiter {
        Ratelimiter {
            buckets: HashMap::new(),
        }
    }

    /// Attempts to take one pass for `key` and sends the verdict on `tx`.
    ///
    /// The limit applied is `passes` passes per `time` milliseconds. A bucket is created
    /// the first time a key is seen.
    ///
    /// The verdict is always sent as `Ok(..)`. If the receiving side of `tx` has already
    /// been dropped the verdict is discarded instead of panicking, so a requester that
    /// went away cannot take the limiter down. Note that a granted pass still counts
    /// against the limit in that case.
    pub fn pass(&mut self, tx: RatelimitSender, key: &str, passes: u32, time: u32) {
        // Look up by `&str` first so the key is only copied when a new bucket is needed.
        let bucket = match self.buckets.get_mut(key) {
            Some(existing) => existing,
            None => self
                .buckets
                .entry(key.to_owned())
                .or_insert_with(Bucket::new),
        };
        let reply = bucket.pass(passes as usize, time);

        // `send` only fails when the receiver is gone; there is nobody left to notify.
        let _ = tx.send(Ok(reply));
    }
}

/// Sliding-window record of the passes recently granted for a single key.
struct Bucket {
    /// Instants at which passes were granted, oldest first.
    ///
    /// A monotonic clock is used so that wall-clock adjustments (NTP steps, manual
    /// changes) can neither expire the window early nor stall it.
    passes: Vec<std::time::Instant>,
}

impl Bucket {
    /// Creates a bucket with no recorded passes.
    pub fn new() -> Bucket {
        Bucket { passes: Vec::new() }
    }

    /// Tries to take one pass from a window allowing `passes` passes per `time`
    /// milliseconds.
    ///
    /// Returns `RatelimitResponse::Pass` and records the pass when fewer than `passes`
    /// passes were granted within the last `time` milliseconds. Otherwise returns
    /// `RatelimitResponse::Retry` carrying the time remaining until the oldest recorded
    /// pass leaves the window. When `passes` is zero nothing can ever be granted and a
    /// fixed 100 ms retry is returned.
    pub fn pass(&mut self, passes: usize, time: u32) -> RatelimitResponse {
        let now = std::time::Instant::now();
        let window = Duration::from_millis(u64::from(time));

        // Forget passes that have aged out of the window.
        self.passes
            .retain(|granted| now.saturating_duration_since(*granted) <= window);

        if self.passes.len() < passes {
            self.passes.push(now);
            return RatelimitResponse::Pass;
        }

        // Stamps are pushed in non-decreasing order and `retain` keeps that order,
        // so the first entry is the oldest one still inside the window.
        match self.passes.first() {
            Some(oldest) => {
                let age = now.saturating_duration_since(*oldest);
                RatelimitResponse::Retry(window.saturating_sub(age))
            }
            // Only reachable when `passes` is zero.
            None => RatelimitResponse::Retry(Duration::from_millis(100)),
        }
    }
}

/// Outcome of a rate-limit check, as sent back over a `RatelimitSender`.
#[derive(Debug, Serialize, Deserialize)]
pub enum RatelimitResponse {
    /// The limit has been reached. Carries the time until the oldest counted pass
    /// leaves the window, or a fixed 100 ms when the limit allows zero passes.
    Retry(Duration),
    /// The pass was granted and has been counted against the limit.
    Pass,
}