Skip to main content

worker_ratelimit/
lib.rs

1use std::collections::BTreeMap;
2use std::time::Duration;
3use worker::Date;
4use worker::kv::KvStore;
5
6#[derive(Debug, thiserror::Error)]
7pub enum Error {
8    #[error(transparent)]
9    Storage(#[from] worker::kv::KvError),
10    #[error(transparent)]
11    Json(#[from] serde_json::Error),
12}
13
14impl From<Error> for worker::Error {
15    fn from(error: Error) -> Self {
16        match error {
17            Error::Storage(error) => error.into(),
18            Error::Json(error) => error.into(),
19        }
20    }
21}
22
23pub type Result<T> = std::result::Result<T, Error>;
24
25#[derive(Debug, PartialEq)]
26pub enum Permit {
27    Allow(Option<Ticket>),
28    Deny,
29}
30
/// Recorded history for one key: maps a timestamp (same unit as
/// `Datetime::timestamp`) to the number of tickets redeemed at that timestamp.
pub type Stamp = BTreeMap<u64, u64>;
32
33pub async fn fetch(kv: &KvStore, key: &str) -> Result<Stamp> {
34    let stamp = if let Some(bytes) = kv.get(key).bytes().await? {
35        serde_json::from_slice::<Stamp>(&bytes)?
36    } else {
37        Stamp::default()
38    };
39    Ok(stamp)
40}
41
/// A point in time, stored as a plain integer timestamp.
///
/// The rest of this module subtracts `Duration::as_secs()` from `timestamp`, so the
/// value is expected to be in seconds.
///
/// Besides `PartialEq`, the type is totally ordered and hashable (it only wraps a
/// `u64`), so it can be compared with `<`/`>` and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Datetime {
    /// Raw timestamp value.
    pub timestamp: u64,
}
46
47impl Datetime {
48    pub fn from_timestamp(timestamp: u64) -> Self {
49        Self { timestamp }
50    }
51}
52
53impl From<&Date> for Datetime {
54    fn from(date: &Date) -> Self {
55        Self::from_timestamp(date.as_millis() / 1000)
56    }
57}
58
/// A set of rate-limit rules sharing one key prefix.
///
/// Derives `Debug` and `Clone` so a configured limiter can be logged and reused.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Prefix prepended (followed by `/`) to the per-client part of every KV key.
    pub prefix: String,
    /// Rules keyed by window length; the value is the hit count at which the
    /// window is considered full.
    pub rules: BTreeMap<Duration, u64>,
}
63
64impl RateLimiter {
65    pub fn new<I: Into<String>>(prefix: I) -> Self {
66        Self {
67            prefix: prefix.into(),
68            rules: BTreeMap::new(),
69        }
70    }
71
72    pub fn add_limit(&mut self, duration: Duration, amount: u64) {
73        self.rules.insert(duration, amount);
74    }
75
76    pub fn check_stamp<D: Into<Datetime>>(
77        &self,
78        stamp: &Stamp,
79        now: D,
80    ) -> (Permit, Option<Duration>) {
81        let now = now.into();
82
83        let mut max = None;
84        for (duration, amount) in &self.rules {
85            let start = now.timestamp - duration.as_secs();
86            let end = now.timestamp;
87
88            let mut sum = 0;
89            for (_timestamp, num) in stamp.range(start..=end) {
90                sum += num;
91            }
92
93            if sum >= *amount {
94                return (Permit::Deny, None);
95            }
96
97            max = Some(*duration);
98        }
99        (Permit::Allow(None), max)
100    }
101
102    pub async fn check_kv<D: Into<Datetime>>(
103        &self,
104        kv: &KvStore,
105        ip_addr: &str,
106        now: D,
107    ) -> Result<Permit> {
108        let now = now.into();
109
110        let key = format!("{}/{}", self.prefix, ip_addr);
111        let stamp = fetch(kv, &key).await?;
112        let (mut permit, max) = self.check_stamp(&stamp, now);
113
114        // if the action is allowed, and there was at least one rule set, issue a ticket
115        if let (Permit::Allow(ticket), Some(max)) = (&mut permit, max) {
116            *ticket = Some(Ticket {
117                key,
118                datetime: now,
119                max,
120            });
121        }
122
123        Ok(permit)
124    }
125}
126
127#[derive(Debug, PartialEq)]
128pub struct Ticket {
129    pub key: String,
130    pub datetime: Datetime,
131    pub max: Duration,
132}
133
134impl Ticket {
135    fn expire(&self, stamp: &mut Stamp) {
136        let cutoff = self.datetime.timestamp - self.max.as_secs();
137        *stamp = stamp.split_off(&cutoff);
138    }
139
140    pub async fn redeem(self, kv: &KvStore) -> Result<()> {
141        let mut stamp = fetch(kv, &self.key).await?;
142        self.expire(&mut stamp);
143
144        let counter = stamp.entry(self.datetime.timestamp).or_default();
145        *counter = counter.saturating_add(1);
146
147        let bytes = serde_json::to_vec(&stamp)?;
148        kv.put_bytes(&self.key, &bytes)?
149            .expiration_ttl(self.max.as_secs() + 1)
150            .execute()
151            .await?;
152
153        Ok(())
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Timestamp used as "now" by the `check_stamp` tests.
    const NOW: u64 = 1710528366;

    /// Builds a limiter whose 5-second window is full at 2 hits.
    fn two_per_five_seconds() -> RateLimiter {
        let mut limits = RateLimiter::new("ratelimit");
        limits.add_limit(Duration::from_secs(5), 2);
        limits
    }

    /// Runs `check_stamp` at `NOW` against a stamp built from `entries`.
    fn permit_for(entries: &[(u64, u64)]) -> Permit {
        let stamp: Stamp = entries.iter().copied().collect();
        let date = Datetime::from_timestamp(NOW);
        let (permit, _) = two_per_five_seconds().check_stamp(&stamp, date);
        permit
    }

    #[test]
    fn test_stamp_check_allow_empty() {
        assert_eq!(permit_for(&[]), Permit::Allow(None));
    }

    #[test]
    fn test_stamp_check_allow_some() {
        assert_eq!(permit_for(&[(1710528362, 1)]), Permit::Allow(None));
    }

    #[test]
    fn test_stamp_check_deny() {
        assert_eq!(
            permit_for(&[(1710528364, 1), (1710528363, 1)]),
            Permit::Deny
        );
    }

    #[test]
    fn test_expire_stamp() {
        let ticket = Ticket {
            key: "abc".to_string(),
            datetime: Datetime::from_timestamp(1710550643),
            max: Duration::from_secs(30),
        };

        let mut stamp = Stamp::from([
            (1710550615, 3),
            (1710550614, 4),
            (1710550613, 7),
            (1710550612, 1),
            (1710550611, 9),
        ]);
        ticket.expire(&mut stamp);

        // Only entries at or after the cutoff (1710550643 - 30) survive.
        let expected = Stamp::from([(1710550615, 3), (1710550614, 4), (1710550613, 7)]);
        assert_eq!(stamp, expected);
    }
}