1use std::collections::BTreeMap;
2use std::time::Duration;
3#[cfg(feature = "worker-sdk")]
4use worker::Date;
5use worker_kv::KvStore;
6
7#[derive(Debug, thiserror::Error)]
8pub enum Error {
9 #[error(transparent)]
10 Storage(#[from] worker_kv::KvError),
11 #[error(transparent)]
12 Json(#[from] serde_json::Error),
13}
14
/// Lets this crate's [`Error`] be propagated with `?` from code that returns
/// `worker::Error`. Only compiled when the `worker-sdk` feature is enabled.
#[cfg(feature = "worker-sdk")]
impl From<Error> for worker::Error {
    /// Unwraps the variant and converts the inner error with its own
    /// `Into<worker::Error>` conversion.
    fn from(err: Error) -> Self {
        match err {
            Error::Storage(source) => source.into(),
            Error::Json(source) => source.into(),
        }
    }
}
24
25pub type Result<T> = std::result::Result<T, Error>;
26
27#[derive(Debug, PartialEq)]
28pub enum Permit {
29 Allow(Option<Ticket>),
30 Deny,
31}
32
/// Per-key request history: maps a timestamp (in seconds) to the number of
/// requests recorded at that timestamp. Ordered so that time windows can be
/// selected with range queries.
pub type Stamp = BTreeMap<u64, u64>;
34
35pub async fn fetch(kv: &KvStore, key: &str) -> Result<Stamp> {
36 let stamp = if let Some(bytes) = kv.get(key).bytes().await? {
37 serde_json::from_slice::<Stamp>(&bytes)?
38 } else {
39 Stamp::default()
40 };
41 Ok(stamp)
42}
43
/// A point in time expressed as a whole number of seconds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Datetime {
    /// Timestamp in seconds.
    pub timestamp: u64,
}

impl Datetime {
    /// Wraps a raw timestamp (in seconds) in a [`Datetime`].
    pub fn from_timestamp(timestamp: u64) -> Self {
        Datetime { timestamp }
    }
}
54
/// Converts a `worker::Date` into a [`Datetime`]. Only compiled when the
/// `worker-sdk` feature is enabled.
#[cfg(feature = "worker-sdk")]
impl From<&Date> for Datetime {
    /// Truncates the date's millisecond value to whole seconds.
    fn from(date: &Date) -> Self {
        const MILLIS_PER_SEC: u64 = 1000;
        // Integer division drops the sub-second remainder.
        Datetime {
            timestamp: date.as_millis() / MILLIS_PER_SEC,
        }
    }
}
61
/// A set of rate-limit rules applied to request histories stored in KV.
pub struct RateLimiter {
    /// Namespace prepended to every KV key built by this limiter.
    pub prefix: String,
    /// Rules keyed by window length; the value is the request count at which
    /// the window is considered exhausted.
    pub rules: BTreeMap<Duration, u64>,
}
66
67impl RateLimiter {
68 pub fn new<I: Into<String>>(prefix: I) -> Self {
69 Self {
70 prefix: prefix.into(),
71 rules: BTreeMap::new(),
72 }
73 }
74
75 pub fn add_limit(&mut self, duration: Duration, amount: u64) {
76 self.rules.insert(duration, amount);
77 }
78
79 pub fn check_stamp<D: Into<Datetime>>(
80 &self,
81 stamp: &Stamp,
82 now: D,
83 ) -> (Permit, Option<Duration>) {
84 let now = now.into();
85
86 let mut max = None;
87 for (duration, amount) in &self.rules {
88 let start = now.timestamp - duration.as_secs();
89 let end = now.timestamp;
90
91 let mut sum = 0;
92 for (_timestamp, num) in stamp.range(start..=end) {
93 sum += num;
94 }
95
96 if sum >= *amount {
97 return (Permit::Deny, None);
98 }
99
100 max = Some(*duration);
101 }
102 (Permit::Allow(None), max)
103 }
104
105 pub async fn check_kv<D: Into<Datetime>>(
106 &self,
107 kv: &KvStore,
108 ip_addr: &str,
109 now: D,
110 ) -> Result<Permit> {
111 let now = now.into();
112
113 let key = format!("{}/{}", self.prefix, ip_addr);
114 let stamp = fetch(kv, &key).await?;
115 let (mut permit, max) = self.check_stamp(&stamp, now);
116
117 if let (Permit::Allow(ticket), Some(max)) = (&mut permit, max) {
119 *ticket = Some(Ticket {
120 key,
121 datetime: now,
122 max,
123 });
124 }
125
126 Ok(permit)
127 }
128}
129
130#[derive(Debug, PartialEq)]
131pub struct Ticket {
132 pub key: String,
133 pub datetime: Datetime,
134 pub max: Duration,
135}
136
137impl Ticket {
138 fn expire(&self, stamp: &mut Stamp) {
139 let cutoff = self.datetime.timestamp - self.max.as_secs();
140 *stamp = stamp.split_off(&cutoff);
141 }
142
143 pub async fn redeem(self, kv: &KvStore) -> Result<()> {
144 let mut stamp = fetch(kv, &self.key).await?;
145 self.expire(&mut stamp);
146
147 let counter = stamp.entry(self.datetime.timestamp).or_default();
148 *counter = counter.saturating_add(1);
149
150 let bytes = serde_json::to_vec(&stamp)?;
151 kv.put_bytes(&self.key, &bytes)?
152 .expiration_ttl(self.max.as_secs() + 1)
153 .execute()
154 .await?;
155
156 Ok(())
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// "Now" shared by the `check_stamp` tests.
    const NOW: u64 = 1710528366;

    /// Runs `check_stamp` with a limiter allowing 2 requests per 5 seconds
    /// against a stamp built from `entries`, and returns only the permit.
    fn permit_for(entries: &[(u64, u64)]) -> Permit {
        let mut limiter = RateLimiter::new("ratelimit");
        limiter.add_limit(Duration::from_secs(5), 2);

        let history: Stamp = entries.iter().copied().collect();
        let (permit, _) = limiter.check_stamp(&history, Datetime::from_timestamp(NOW));
        permit
    }

    /// Compile-time check that `worker::kv::KvStore` and
    /// `worker_kv::KvStore` are the same type.
    #[test]
    fn ensure_compatible_types() {
        let _: Option<worker_kv::KvStore> = Option::<worker::kv::KvStore>::None;
    }

    #[test]
    fn test_stamp_check_allow_empty() {
        assert_eq!(permit_for(&[]), Permit::Allow(None));
    }

    #[test]
    fn test_stamp_check_allow_some() {
        assert_eq!(permit_for(&[(1710528362, 1)]), Permit::Allow(None));
    }

    #[test]
    fn test_stamp_check_deny() {
        assert_eq!(
            permit_for(&[(1710528364, 1), (1710528363, 1)]),
            Permit::Deny
        );
    }

    #[test]
    fn test_expire_stamp() {
        let ticket = Ticket {
            key: "abc".to_string(),
            datetime: Datetime::from_timestamp(1710550643),
            max: Duration::from_secs(30),
        };

        let mut history = Stamp::new();
        history.insert(1710550615, 3);
        history.insert(1710550614, 4);
        history.insert(1710550613, 7);
        history.insert(1710550612, 1);
        history.insert(1710550611, 9);

        ticket.expire(&mut history);

        // Cutoff is 1710550643 - 30 = 1710550613; entries at or after it stay.
        let mut expected = Stamp::new();
        expected.insert(1710550615, 3);
        expected.insert(1710550614, 4);
        expected.insert(1710550613, 7);
        assert_eq!(history, expected);
    }
}