1use std::collections::BTreeMap;
2use std::time::Duration;
3use worker::Date;
4use worker::kv::KvStore;
5
6#[derive(Debug, thiserror::Error)]
7pub enum Error {
8 #[error(transparent)]
9 Storage(#[from] worker::kv::KvError),
10 #[error(transparent)]
11 Json(#[from] serde_json::Error),
12}
13
14impl From<Error> for worker::Error {
15 fn from(error: Error) -> Self {
16 match error {
17 Error::Storage(error) => error.into(),
18 Error::Json(error) => error.into(),
19 }
20 }
21}
22
23pub type Result<T> = std::result::Result<T, Error>;
24
25#[derive(Debug, PartialEq)]
26pub enum Permit {
27 Allow(Option<Ticket>),
28 Deny,
29}
30
/// Request history for one key: maps a timestamp in whole seconds to the
/// number of requests recorded during that second.
pub type Stamp = BTreeMap<u64, u64>;
32
33pub async fn fetch(kv: &KvStore, key: &str) -> Result<Stamp> {
34 let stamp = if let Some(bytes) = kv.get(key).bytes().await? {
35 serde_json::from_slice::<Stamp>(&bytes)?
36 } else {
37 Stamp::default()
38 };
39 Ok(stamp)
40}
41
/// A point in time with one-second resolution.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Datetime {
    /// Timestamp in whole seconds; used as the key into a `Stamp`.
    pub timestamp: u64,
}
46
47impl Datetime {
48 pub fn from_timestamp(timestamp: u64) -> Self {
49 Self { timestamp }
50 }
51}
52
53impl From<&Date> for Datetime {
54 fn from(date: &Date) -> Self {
55 Self::from_timestamp(date.as_millis() / 1000)
56 }
57}
58
/// A set of rate-limit rules plus the prefix used to build storage keys.
pub struct RateLimiter {
    /// Prepended to the client identifier, separated by `/`, to form the KV key.
    pub prefix: String,
    /// Maps a window length to the request count at which the window is full.
    /// Being a `BTreeMap`, rules are iterated from shortest to longest window.
    pub rules: BTreeMap<Duration, u64>,
}
63
64impl RateLimiter {
65 pub fn new<I: Into<String>>(prefix: I) -> Self {
66 Self {
67 prefix: prefix.into(),
68 rules: BTreeMap::new(),
69 }
70 }
71
72 pub fn add_limit(&mut self, duration: Duration, amount: u64) {
73 self.rules.insert(duration, amount);
74 }
75
76 pub fn check_stamp<D: Into<Datetime>>(
77 &self,
78 stamp: &Stamp,
79 now: D,
80 ) -> (Permit, Option<Duration>) {
81 let now = now.into();
82
83 let mut max = None;
84 for (duration, amount) in &self.rules {
85 let start = now.timestamp - duration.as_secs();
86 let end = now.timestamp;
87
88 let mut sum = 0;
89 for (_timestamp, num) in stamp.range(start..=end) {
90 sum += num;
91 }
92
93 if sum >= *amount {
94 return (Permit::Deny, None);
95 }
96
97 max = Some(*duration);
98 }
99 (Permit::Allow(None), max)
100 }
101
102 pub async fn check_kv<D: Into<Datetime>>(
103 &self,
104 kv: &KvStore,
105 ip_addr: &str,
106 now: D,
107 ) -> Result<Permit> {
108 let now = now.into();
109
110 let key = format!("{}/{}", self.prefix, ip_addr);
111 let stamp = fetch(kv, &key).await?;
112 let (mut permit, max) = self.check_stamp(&stamp, now);
113
114 if let (Permit::Allow(ticket), Some(max)) = (&mut permit, max) {
116 *ticket = Some(Ticket {
117 key,
118 datetime: now,
119 max,
120 });
121 }
122
123 Ok(permit)
124 }
125}
126
127#[derive(Debug, PartialEq)]
128pub struct Ticket {
129 pub key: String,
130 pub datetime: Datetime,
131 pub max: Duration,
132}
133
134impl Ticket {
135 fn expire(&self, stamp: &mut Stamp) {
136 let cutoff = self.datetime.timestamp - self.max.as_secs();
137 *stamp = stamp.split_off(&cutoff);
138 }
139
140 pub async fn redeem(self, kv: &KvStore) -> Result<()> {
141 let mut stamp = fetch(kv, &self.key).await?;
142 self.expire(&mut stamp);
143
144 let counter = stamp.entry(self.datetime.timestamp).or_default();
145 *counter = counter.saturating_add(1);
146
147 let bytes = serde_json::to_vec(&stamp)?;
148 kv.put_bytes(&self.key, &bytes)?
149 .expiration_ttl(self.max.as_secs() + 1)
150 .execute()
151 .await?;
152
153 Ok(())
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Limiter shared by the `check_stamp` tests: 2 requests per 5 seconds.
    fn two_per_five_seconds() -> RateLimiter {
        let mut limiter = RateLimiter::new("ratelimit");
        limiter.add_limit(Duration::from_secs(5), 2);
        limiter
    }

    // An empty history is always allowed.
    #[test]
    fn test_stamp_check_allow_empty() {
        let limiter = two_per_five_seconds();
        let history = Stamp::new();
        let now = Datetime::from_timestamp(1710528366);

        let (permit, _) = limiter.check_stamp(&history, now);

        assert_eq!(permit, Permit::Allow(None));
    }

    // One request inside the window is still below the limit of two.
    #[test]
    fn test_stamp_check_allow_some() {
        let limiter = two_per_five_seconds();
        let history = Stamp::from([(1710528362, 1)]);
        let now = Datetime::from_timestamp(1710528366);

        let (permit, _) = limiter.check_stamp(&history, now);

        assert_eq!(permit, Permit::Allow(None));
    }

    // Two requests inside the window reach the limit.
    #[test]
    fn test_stamp_check_deny() {
        let limiter = two_per_five_seconds();
        let history = Stamp::from([(1710528364, 1), (1710528363, 1)]);
        let now = Datetime::from_timestamp(1710528366);

        let (permit, _) = limiter.check_stamp(&history, now);

        assert_eq!(permit, Permit::Deny);
    }

    // Entries before `datetime - max` (1710550613 here) are removed.
    #[test]
    fn test_expire_stamp() {
        let mut history = Stamp::from([
            (1710550615, 3),
            (1710550614, 4),
            (1710550613, 7),
            (1710550612, 1),
            (1710550611, 9),
        ]);
        let ticket = Ticket {
            key: "abc".to_string(),
            datetime: Datetime::from_timestamp(1710550643),
            max: Duration::from_secs(30),
        };

        ticket.expire(&mut history);

        let expected = Stamp::from([(1710550615, 3), (1710550614, 4), (1710550613, 7)]);
        assert_eq!(history, expected);
    }
}