worker_ratelimit/
lib.rs

1use std::collections::BTreeMap;
2use std::time::Duration;
3#[cfg(feature = "worker-sdk")]
4use worker::Date;
5use worker_kv::KvStore;
6
7#[derive(Debug, thiserror::Error)]
8pub enum Error {
9    #[error(transparent)]
10    Storage(#[from] worker_kv::KvError),
11    #[error(transparent)]
12    Json(#[from] serde_json::Error),
13}
14
/// Lets this crate's [`Error`] be returned from Workers SDK handlers via `?`.
///
/// Only compiled when the `worker-sdk` feature is enabled.
#[cfg(feature = "worker-sdk")]
impl From<Error> for worker::Error {
    fn from(source: Error) -> Self {
        // Each payload already converts into `worker::Error`; forward it unchanged.
        match source {
            Error::Json(inner) => inner.into(),
            Error::Storage(inner) => inner.into(),
        }
    }
}
24
25pub type Result<T> = std::result::Result<T, Error>;
26
27#[derive(Debug, PartialEq)]
28pub enum Permit {
29    Allow(Option<Ticket>),
30    Deny,
31}
32
/// Per-second hit counters: maps a timestamp (in seconds) to the number of hits recorded at it.
///
/// Ordered by timestamp so windows can be read with a range query.
pub type Stamp = std::collections::BTreeMap<u64, u64>;
34
35pub async fn fetch(kv: &KvStore, key: &str) -> Result<Stamp> {
36    let stamp = if let Some(bytes) = kv.get(key).bytes().await? {
37        serde_json::from_slice::<Stamp>(&bytes)?
38    } else {
39        Stamp::default()
40    };
41    Ok(stamp)
42}
43
/// A point in time with whole-second resolution.
///
/// Presumably Unix-epoch seconds: the `worker::Date` conversion divides milliseconds by 1000.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Datetime {
    /// Time in whole seconds.
    pub timestamp: u64,
}

impl Datetime {
    /// Wraps a raw timestamp in seconds; no validation is performed.
    pub fn from_timestamp(timestamp: u64) -> Self {
        Datetime { timestamp }
    }
}
54
/// Converts a Workers SDK `Date` to a [`Datetime`], truncating milliseconds to whole seconds.
///
/// Only compiled when the `worker-sdk` feature is enabled.
#[cfg(feature = "worker-sdk")]
impl From<&Date> for Datetime {
    fn from(date: &Date) -> Self {
        let millis = date.as_millis();
        Datetime {
            timestamp: millis / 1000,
        }
    }
}
61
/// A set of sliding-window rate-limit rules sharing one KV key prefix.
///
/// Each rule maps a window length to the number of hits at which further requests are denied.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Prefix for KV keys; the full key is `{prefix}/{ip_addr}`.
    pub prefix: String,
    /// Window length → hit count at which requests inside that window are denied.
    pub rules: BTreeMap<Duration, u64>,
}
66
67impl RateLimiter {
68    pub fn new<I: Into<String>>(prefix: I) -> Self {
69        Self {
70            prefix: prefix.into(),
71            rules: BTreeMap::new(),
72        }
73    }
74
75    pub fn add_limit(&mut self, duration: Duration, amount: u64) {
76        self.rules.insert(duration, amount);
77    }
78
79    pub fn check_stamp<D: Into<Datetime>>(
80        &self,
81        stamp: &Stamp,
82        now: D,
83    ) -> (Permit, Option<Duration>) {
84        let now = now.into();
85
86        let mut max = None;
87        for (duration, amount) in &self.rules {
88            let start = now.timestamp - duration.as_secs();
89            let end = now.timestamp;
90
91            let mut sum = 0;
92            for (_timestamp, num) in stamp.range(start..=end) {
93                sum += num;
94            }
95
96            if sum >= *amount {
97                return (Permit::Deny, None);
98            }
99
100            max = Some(*duration);
101        }
102        (Permit::Allow(None), max)
103    }
104
105    pub async fn check_kv<D: Into<Datetime>>(
106        &self,
107        kv: &KvStore,
108        ip_addr: &str,
109        now: D,
110    ) -> Result<Permit> {
111        let now = now.into();
112
113        let key = format!("{}/{}", self.prefix, ip_addr);
114        let stamp = fetch(kv, &key).await?;
115        let (mut permit, max) = self.check_stamp(&stamp, now);
116
117        // if the action is allowed, and there was at least one rule set, issue a ticket
118        if let (Permit::Allow(ticket), Some(max)) = (&mut permit, max) {
119            *ticket = Some(Ticket {
120                key,
121                datetime: now,
122                max,
123            });
124        }
125
126        Ok(permit)
127    }
128}
129
130#[derive(Debug, PartialEq)]
131pub struct Ticket {
132    pub key: String,
133    pub datetime: Datetime,
134    pub max: Duration,
135}
136
137impl Ticket {
138    fn expire(&self, stamp: &mut Stamp) {
139        let cutoff = self.datetime.timestamp - self.max.as_secs();
140        *stamp = stamp.split_off(&cutoff);
141    }
142
143    pub async fn redeem(self, kv: &KvStore) -> Result<()> {
144        let mut stamp = fetch(kv, &self.key).await?;
145        self.expire(&mut stamp);
146
147        let counter = stamp.entry(self.datetime.timestamp).or_default();
148        *counter = counter.saturating_add(1);
149
150        let bytes = serde_json::to_vec(&stamp)?;
151        kv.put_bytes(&self.key, &bytes)?
152            .expiration_ttl(self.max.as_secs() + 1)
153            .execute()
154            .await?;
155
156        Ok(())
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// "Current" time shared by the `check_stamp` tests.
    const NOW: u64 = 1710528366;

    /// A limiter that denies once 2 hits are recorded within a 5 second window.
    fn five_second_limiter() -> RateLimiter {
        let mut limiter = RateLimiter::new("ratelimit");
        limiter.add_limit(Duration::from_secs(5), 2);
        limiter
    }

    /// Evaluates `stamp` with `five_second_limiter` at `NOW`, keeping only the permit.
    fn permit_at_now(stamp: &Stamp) -> Permit {
        let now = Datetime::from_timestamp(NOW);
        let (permit, _max) = five_second_limiter().check_stamp(stamp, now);
        permit
    }

    #[test]
    fn ensure_compatible_types() {
        // Only compiles if `worker_kv::KvStore` and `worker::kv::KvStore` are the same type.
        let _: Option<worker_kv::KvStore> = Option::<worker::kv::KvStore>::None;
    }

    #[test]
    fn test_stamp_check_allow_empty() {
        assert_eq!(permit_at_now(&Stamp::new()), Permit::Allow(None));
    }

    #[test]
    fn test_stamp_check_allow_some() {
        // One hit inside the window stays below the limit of 2.
        let stamp = Stamp::from([(1710528362, 1)]);
        assert_eq!(permit_at_now(&stamp), Permit::Allow(None));
    }

    #[test]
    fn test_stamp_check_deny() {
        // Two hits inside the window reach the limit of 2.
        let stamp = Stamp::from([(1710528363, 1), (1710528364, 1)]);
        assert_eq!(permit_at_now(&stamp), Permit::Deny);
    }

    #[test]
    fn test_expire_stamp() {
        let mut stamp = Stamp::from([
            (1710550611, 9),
            (1710550612, 1),
            (1710550613, 7),
            (1710550614, 4),
            (1710550615, 3),
        ]);
        let ticket = Ticket {
            key: String::from("abc"),
            datetime: Datetime::from_timestamp(1710550643),
            max: Duration::from_secs(30),
        };
        ticket.expire(&mut stamp);

        // Cutoff is 1710550643 - 30 = 1710550613; that entry and everything newer survive.
        let expected = Stamp::from([(1710550613, 7), (1710550614, 4), (1710550615, 3)]);
        assert_eq!(stamp, expected);
    }
}
224}