Skip to main content

s4_server/
rate_limit.rs

1//! Per-(principal, bucket) token-bucket rate limiting (v0.4 #19).
2//!
3//! Operators describe the rules in JSON:
4//!
5//! ```json
6//! [
7//!   {"principal": "AKIATENANT_A", "bucket": "tenant-a-*", "rps": 100, "burst": 500},
8//!   {"principal": "*",            "bucket": "*",          "rps":  20, "burst":  60}
9//! ]
10//! ```
11//!
//! Match precedence is **first-match-wins** in file order — the JSON
//! file's order is preserved, so put narrow rules above wildcards to get
//! most-specific-first behaviour. Patterns are simple globs: `*` matches
//! any sequence (including the empty one) and `?` matches exactly one
//! byte (i.e. one ASCII character).
15//!
16//! On each PUT / GET / DELETE / List, the matching rule's bucket consumes
17//! one token. If the bucket is empty the request is rejected with
18//! `S3ErrorCode::SlowDown` (HTTP 503; AWS-spec response for "you're
19//! making requests faster than I can handle"). The
20//! `s4_rate_limit_throttled_total{principal,bucket}` Prometheus counter is
21//! bumped on every reject.
22
23use std::num::NonZeroU32;
24use std::path::Path;
25use std::sync::Arc;
26
27use dashmap::DashMap;
28use governor::clock::DefaultClock;
29use governor::state::{InMemoryState, NotKeyed};
30use governor::{Quota, RateLimiter};
31use serde::Deserialize;
32
33use crate::policy; // re-use the glob_match helper if exposed; otherwise inline below
34
35#[derive(Debug, Clone, Deserialize)]
36pub struct Rule {
37    /// `*` for any principal.
38    pub principal: String,
39    /// `*` for any bucket.
40    pub bucket: String,
41    /// Sustained requests per second.
42    pub rps: u32,
43    /// Initial / replenished bucket size.
44    pub burst: u32,
45}
46
47/// Compiled per-(principal, bucket) limiter pool. Rules are evaluated in
48/// the order they appear in the JSON; the first match wins.
49#[derive(Clone)]
50pub struct RateLimits {
51    rules: Arc<Vec<Rule>>,
52    /// Per-(rule index, principal, bucket) limiters. Created lazily —
53    /// the first request from a given principal/bucket pair instantiates
54    /// the limiter, subsequent requests reuse it.
55    limiters: Arc<DashMap<(usize, String, String), Arc<KeyLimiter>>>,
56}
57
58type KeyLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
59
60impl RateLimits {
61    pub fn from_json_str(s: &str) -> Result<Self, String> {
62        let rules: Vec<Rule> =
63            serde_json::from_str(s).map_err(|e| format!("rate-limit JSON parse error: {e}"))?;
64        for r in &rules {
65            if r.rps == 0 || r.burst == 0 {
66                return Err(format!(
67                    "rate-limit rule has rps=0 or burst=0 (would deny everything): {r:?}"
68                ));
69            }
70        }
71        Ok(Self {
72            rules: Arc::new(rules),
73            limiters: Arc::new(DashMap::new()),
74        })
75    }
76
77    pub fn from_path(path: &Path) -> Result<Self, String> {
78        let txt = std::fs::read_to_string(path)
79            .map_err(|e| format!("failed to read {}: {e}", path.display()))?;
80        Self::from_json_str(&txt)
81    }
82
83    /// Returns `true` if the request passes the limiter, `false` if
84    /// throttled. `principal_id` may be `None` (anonymous); rules with
85    /// `"principal": "*"` still match.
86    pub fn check(&self, principal_id: Option<&str>, bucket: &str) -> bool {
87        let principal = principal_id.unwrap_or("");
88        for (idx, rule) in self.rules.iter().enumerate() {
89            if !glob_match(&rule.principal, principal) {
90                continue;
91            }
92            if !glob_match(&rule.bucket, bucket) {
93                continue;
94            }
95            let key = (idx, principal.to_owned(), bucket.to_owned());
96            let limiter = self
97                .limiters
98                .entry(key)
99                .or_insert_with(|| {
100                    let burst = NonZeroU32::new(rule.burst).expect("burst > 0 (validated)");
101                    let rps = NonZeroU32::new(rule.rps).expect("rps > 0 (validated)");
102                    let quota = Quota::per_second(rps).allow_burst(burst);
103                    Arc::new(RateLimiter::direct(quota))
104                })
105                .clone();
106            return limiter.check().is_ok();
107        }
108        // No rule matched → no limit applies.
109        true
110    }
111}
112
113impl std::fmt::Debug for RateLimits {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("RateLimits")
116            .field("rules", &self.rules.len())
117            .field("active_limiters", &self.limiters.len())
118            .finish()
119    }
120}
121
122pub type SharedRateLimits = Arc<RateLimits>;
123
124/// Local minimal glob — same semantics as policy::glob_match but
125/// re-exposed here so we don't have to expose internals from `policy`.
126/// `*` = any sequence, `?` = any single char. Case-sensitive.
127fn glob_match(pattern: &str, s: &str) -> bool {
128    glob_match_bytes(pattern.as_bytes(), s.as_bytes())
129}
130
/// Byte-level glob matcher: returns `true` when pattern `p` matches the
/// whole of input `s`.
///
/// In the pattern, `*` matches any run of bytes (including none) and `?`
/// matches exactly one byte; every other byte must match itself. There is
/// no escape syntax, so `*` and `?` in the pattern are always wildcards.
///
/// Runs iteratively with a single backtrack point (the most recent `*`),
/// so no recursion or allocation is needed.
fn glob_match_bytes(p: &[u8], s: &[u8]) -> bool {
    let mut pi = 0;
    let mut si = 0;
    // (index of the most recent `*` in `p`, input index it currently
    // extends up to). Used to retry with the star swallowing one more byte.
    let mut star: Option<(usize, usize)> = None;

    while si < s.len() {
        match p.get(pi).copied() {
            // The wildcard arm must come before the literal arm: otherwise
            // a `*` in the pattern is consumed as a literal whenever the
            // input byte is also `*`, and e.g. `*` fails to match `*abc`.
            Some(b'*') => {
                star = Some((pi, si));
                pi += 1;
            }
            Some(c) if c == b'?' || c == s[si] => {
                pi += 1;
                si += 1;
            }
            _ => match star {
                // Mismatch: let the last `*` absorb one more input byte
                // and resume matching right after it.
                Some((sp, ss)) => {
                    pi = sp + 1;
                    si = ss + 1;
                    star = Some((sp, si));
                }
                None => return false,
            },
        }
    }

    // Input exhausted: the pattern matches only if all that remains of it
    // is `*`s (which can match the empty string).
    p[pi..].iter().all(|&b| b == b'*')
}
155
156// Touch a policy item to keep the import live (otherwise unused-imports fires
157// without changing visibility); we use the same matching shape on purpose.
158#[allow(dead_code)]
159fn _link() -> Option<policy::Effect> {
160    None
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Parses a JSON fixture into a `RateLimits`, panicking if it is invalid.
    fn rl(s: &str) -> RateLimits {
        RateLimits::from_json_str(s).expect("rate-limit parse")
    }

    /// A zero `rps` or a zero `burst` must be refused at load time, and the
    /// error text must name the offending setting.
    #[test]
    fn parse_rejects_zero_rps_or_burst() {
        let cases = [
            (
                r#"[{"principal": "*", "bucket": "*", "rps": 0, "burst": 10}]"#,
                "rps=0",
            ),
            (
                r#"[{"principal": "*", "bucket": "*", "rps": 1, "burst": 0}]"#,
                "burst=0",
            ),
        ];
        for (json, needle) in cases {
            let err = RateLimits::from_json_str(json).unwrap_err();
            assert!(err.contains(needle));
        }
    }

    /// The narrow rule is listed first and wins for matching callers;
    /// everyone else falls through to the catch-all rule.
    #[test]
    fn match_principal_and_bucket_globs() {
        let limits = rl(r#"[
            {"principal": "AKIA*", "bucket": "tenant-a-*", "rps": 1000, "burst": 1000},
            {"principal": "*",     "bucket": "*",          "rps": 1,    "burst": 1}
        ]"#);

        // Hits the first rule, which has a generous quota.
        let tenant_allowed = limits.check(Some("AKIATEST"), "tenant-a-foo");
        assert!(tenant_allowed);

        // Hits the catch-all rule: its single burst token is spent here...
        let first_other = limits.check(Some("anonymous"), "any");
        assert!(first_other);

        // ...so the very next request for the same key is throttled.
        let second_other = limits.check(Some("anonymous"), "any");
        assert!(!second_other);
    }

    /// A caller that matches no rule is never throttled.
    #[test]
    fn no_rule_means_no_limit() {
        let limits = rl(r#"[{"principal": "AKIATENANT", "bucket": "*", "rps": 1, "burst": 1}]"#);
        let always_allowed = (0..100).all(|_| limits.check(Some("AKIAOTHER"), "anything"));
        assert!(always_allowed);
    }

    /// After the burst is spent, waiting longer than one refill interval
    /// makes a token available again.
    #[test]
    fn refill_after_wait() {
        let limits = rl(r#"[{"principal": "*", "bucket": "*", "rps": 100, "burst": 1}]"#);
        assert!(limits.check(None, "b"));
        assert!(!limits.check(None, "b"));

        // 100 rps = 1 token / 10 ms, so 15 ms is enough for one refill.
        let pause = Duration::from_millis(15);
        std::thread::sleep(pause);

        assert!(limits.check(None, "b"));
    }
}