// saorsa_core/rate_limit.rs

1use parking_lot::RwLock;
2use std::collections::HashMap;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
/// Limits applied by [`Engine`] to its global bucket and to every per-key bucket.
///
/// NOTE(review): no validation happens here. A zero `window` makes the refill
/// rate `max_requests / 0.0` (infinite or NaN) in the bucket logic.
#[derive(Clone, Debug)]
pub struct EngineConfig {
    /// Length of the fixed counting window. Also the period over which
    /// `max_requests` tokens are refilled.
    pub window: Duration,
    /// Maximum successful consumptions per bucket within one window, and the
    /// number of tokens credited per `window`.
    pub max_requests: u32,
    /// Token capacity of each bucket; new buckets start with this many tokens.
    pub burst_size: u32,
}

/// State of a single limiter bucket: a fractional token balance plus a
/// request counter for the current fixed window.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; a request needs at least `1.0`.
    tokens: f64,
    /// Instant at which `tokens` was last refilled.
    last_update: Instant,
    /// Successful consumptions since `window_start`.
    requests_in_window: u32,
    /// Start of the current counting window.
    window_start: Instant,
}

22impl Bucket {
23    fn new(initial_tokens: f64) -> Self {
24        let now = Instant::now();
25        Self {
26            tokens: initial_tokens,
27            last_update: now,
28            requests_in_window: 0,
29            window_start: now,
30        }
31    }
32
33    fn try_consume(&mut self, cfg: &EngineConfig) -> bool {
34        let now = Instant::now();
35        if now.duration_since(self.window_start) > cfg.window {
36            self.window_start = now;
37            self.requests_in_window = 0;
38        }
39        let elapsed = now.duration_since(self.last_update).as_secs_f64();
40        let refill_rate = cfg.max_requests as f64 / cfg.window.as_secs_f64();
41        self.tokens += elapsed * refill_rate;
42        self.tokens = self.tokens.min(cfg.burst_size as f64);
43        self.last_update = now;
44        if self.tokens >= 1.0 && self.requests_in_window < cfg.max_requests {
45            self.tokens -= 1.0;
46            self.requests_in_window += 1;
47            true
48        } else {
49            false
50        }
51    }
52}
53
54#[derive(Debug)]
55pub struct Engine<K: Eq + Hash + Clone + ToString> {
56    cfg: EngineConfig,
57    global: Mutex<Bucket>,
58    keyed: RwLock<HashMap<K, Bucket>>,
59}
60
61impl<K: Eq + Hash + Clone + ToString> Engine<K> {
62    pub fn new(cfg: EngineConfig) -> Self {
63        let burst_size = cfg.burst_size as f64;
64        Self {
65            cfg,
66            global: Mutex::new(Bucket::new(burst_size)),
67            keyed: RwLock::new(HashMap::new()),
68        }
69    }
70
71    pub fn try_consume_global(&self) -> bool {
72        match self.global.lock() {
73            Ok(mut guard) => guard.try_consume(&self.cfg),
74            Err(_poisoned) => {
75                // Treat poisoned mutex as a denial to maintain safety
76                // and avoid panicking in production code.
77                false
78            }
79        }
80    }
81
82    pub fn try_consume_key(&self, key: &K) -> bool {
83        let mut map = self.keyed.write();
84        let bucket = map
85            .entry(key.clone())
86            .or_insert_with(|| Bucket::new(self.cfg.burst_size as f64));
87        bucket.try_consume(&self.cfg)
88    }
89}
90
91pub type SharedEngine<K> = Arc<Engine<K>>;