1use std::num::NonZeroU32;
24use std::path::Path;
25use std::sync::Arc;
26
27use dashmap::DashMap;
28use governor::clock::DefaultClock;
29use governor::state::{InMemoryState, NotKeyed};
30use governor::{Quota, RateLimiter};
31use serde::Deserialize;
32
use crate::policy;

/// One rate-limit rule, deserialized from a JSON object.
///
/// Rules are evaluated in list order by `RateLimits::check`; the first rule
/// whose `principal` and `bucket` globs both match a request is the one applied.
#[derive(Debug, Clone, Deserialize)]
pub struct Rule {
    /// Glob (`*` / `?`) matched against the caller's principal id. Requests
    /// without a principal are matched as the empty string.
    pub principal: String,
    /// Glob (`*` / `?`) matched against the bucket name.
    pub bucket: String,
    /// Sustained rate passed to `Quota::per_second`; must be non-zero
    /// (rejected by `RateLimits::from_json_str` otherwise).
    pub rps: u32,
    /// Burst size passed to `Quota::allow_burst`; must be non-zero
    /// (rejected by `RateLimits::from_json_str` otherwise).
    pub burst: u32,
}
46
47#[derive(Clone)]
50pub struct RateLimits {
51 rules: Arc<Vec<Rule>>,
52 limiters: Arc<DashMap<(usize, String, String), Arc<KeyLimiter>>>,
56}
57
58type KeyLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
59
60impl RateLimits {
61 pub fn from_json_str(s: &str) -> Result<Self, String> {
62 let rules: Vec<Rule> =
63 serde_json::from_str(s).map_err(|e| format!("rate-limit JSON parse error: {e}"))?;
64 for r in &rules {
65 if r.rps == 0 || r.burst == 0 {
66 return Err(format!(
67 "rate-limit rule has rps=0 or burst=0 (would deny everything): {r:?}"
68 ));
69 }
70 }
71 Ok(Self {
72 rules: Arc::new(rules),
73 limiters: Arc::new(DashMap::new()),
74 })
75 }
76
77 pub fn from_path(path: &Path) -> Result<Self, String> {
78 let txt = std::fs::read_to_string(path)
79 .map_err(|e| format!("failed to read {}: {e}", path.display()))?;
80 Self::from_json_str(&txt)
81 }
82
83 pub fn check(&self, principal_id: Option<&str>, bucket: &str) -> bool {
87 let principal = principal_id.unwrap_or("");
88 for (idx, rule) in self.rules.iter().enumerate() {
89 if !glob_match(&rule.principal, principal) {
90 continue;
91 }
92 if !glob_match(&rule.bucket, bucket) {
93 continue;
94 }
95 let key = (idx, principal.to_owned(), bucket.to_owned());
96 let limiter = self
97 .limiters
98 .entry(key)
99 .or_insert_with(|| {
100 let burst = NonZeroU32::new(rule.burst).expect("burst > 0 (validated)");
101 let rps = NonZeroU32::new(rule.rps).expect("rps > 0 (validated)");
102 let quota = Quota::per_second(rps).allow_burst(burst);
103 Arc::new(RateLimiter::direct(quota))
104 })
105 .clone();
106 return limiter.check().is_ok();
107 }
108 true
110 }
111}
112
113impl std::fmt::Debug for RateLimits {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("RateLimits")
116 .field("rules", &self.rules.len())
117 .field("active_limiters", &self.limiters.len())
118 .finish()
119 }
120}
121
122pub type SharedRateLimits = Arc<RateLimits>;
123
/// Matches `s` against a shell-style glob `pattern`.
///
/// `*` matches any run of characters (including none), and `?` matches exactly
/// one character. Every other character matches only itself. There is no escape
/// syntax, so `*` and `?` in `pattern` are always wildcards. The pattern must
/// match the whole of `s`.
fn glob_match(pattern: &str, s: &str) -> bool {
    glob_match_bytes(pattern.as_bytes(), s.as_bytes())
}

/// Byte-level engine behind [`glob_match`].
///
/// This is iterative greedy matching that keeps a single backtrack point (the
/// most recent `*`), so it needs no recursion and runs in O(|p| * |s|) worst case.
///
/// Inputs are expected to be UTF-8, as produced by `str::as_bytes`. Both `?` and
/// the "let `*` absorb one more" step advance over a whole UTF-8 sequence, so the
/// wildcards operate on characters rather than on individual bytes. A byte that
/// is not a valid UTF-8 lead byte is treated as a one-byte character.
fn glob_match_bytes(p: &[u8], s: &[u8]) -> bool {
    // Length of the UTF-8 sequence introduced by `lead` (1 for ASCII and for any
    // byte that is not a multi-byte lead byte).
    fn char_len(lead: u8) -> usize {
        match lead {
            0xC0..=0xDF => 2,
            0xE0..=0xEF => 3,
            0xF0..=0xF7 => 4,
            _ => 1,
        }
    }

    let mut pi = 0;
    let mut si = 0;
    // (index of the last `*` in `p`, position in `s` where that `*` currently stops)
    let mut star: Option<(usize, usize)> = None;
    while si < s.len() {
        // `*` is tested before literal equality. Otherwise a literal `*` in the
        // subject would be consumed by a pattern `*` without recording a
        // backtrack point, and e.g. "*x" would fail to match "*ax".
        if pi < p.len() && p[pi] == b'*' {
            star = Some((pi, si));
            pi += 1;
        } else if pi < p.len() && p[pi] == b'?' {
            // Consume one whole character (clamped in case `s` is truncated).
            si = (si + char_len(s[si])).min(s.len());
            pi += 1;
        } else if pi < p.len() && p[pi] == s[si] {
            pi += 1;
            si += 1;
        } else if let Some((sp, ss)) = star {
            // Let the last `*` absorb one more character and retry after it.
            let next = (ss + char_len(s[ss])).min(s.len());
            pi = sp + 1;
            si = next;
            star = Some((sp, next));
        } else {
            return false;
        }
    }
    // Trailing `*`s can match the empty remainder of `s`.
    while pi < p.len() && p[pi] == b'*' {
        pi += 1;
    }
    pi == p.len()
}
155
156#[allow(dead_code)]
159fn _link() -> Option<policy::Effect> {
160 None
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Parses a rule table or fails the test.
    fn rl(s: &str) -> RateLimits {
        RateLimits::from_json_str(s).expect("rate-limit parse")
    }

    #[test]
    fn parse_rejects_zero_rps_or_burst() {
        let cases = [
            (r#"[{"principal": "*", "bucket": "*", "rps": 0, "burst": 10}]"#, "rps=0"),
            (r#"[{"principal": "*", "bucket": "*", "rps": 1, "burst": 0}]"#, "burst=0"),
        ];
        for (json, needle) in cases {
            let err = RateLimits::from_json_str(json).unwrap_err();
            assert!(err.contains(needle), "{err:?} should mention {needle}");
        }
    }

    #[test]
    fn match_principal_and_bucket_globs() {
        let r = rl(r#"[
            {"principal": "AKIA*", "bucket": "tenant-a-*", "rps": 1000, "burst": 1000},
            {"principal": "*", "bucket": "*", "rps": 1, "burst": 1}
        ]"#);
        // The specific first rule applies and has plenty of budget.
        assert!(r.check(Some("AKIATEST"), "tenant-a-foo"));
        // The catch-all rule has burst 1: the first call passes, the second is refused.
        assert!(r.check(Some("anonymous"), "any"));
        assert!(!r.check(Some("anonymous"), "any"));
    }

    #[test]
    fn no_rule_means_no_limit() {
        let r = rl(r#"[{"principal": "AKIATENANT", "bucket": "*", "rps": 1, "burst": 1}]"#);
        // No rule matches this principal, so every request is admitted.
        assert!((0..100).all(|_| r.check(Some("AKIAOTHER"), "anything")));
    }

    #[test]
    fn refill_after_wait() {
        // rps 100 means one cell is replenished about every 10 ms; burst 1 means
        // only a single request is available up front.
        let r = rl(r#"[{"principal": "*", "bucket": "*", "rps": 100, "burst": 1}]"#);
        assert!(r.check(None, "b"));
        assert!(!r.check(None, "b"));
        std::thread::sleep(Duration::from_millis(15));
        assert!(r.check(None, "b"));
    }
}