tork_core/throttle/
store.rs1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
7use crate::error::Result;
8use crate::router::BoxFuture;
9
/// Backend that counts hits per key within fixed, expiring windows.
///
/// The `Send + Sync + 'static` bounds let a single store be shared between
/// threads and tasks.
pub trait ThrottleStore: Send + Sync + 'static {
    /// Records one hit for `key` and returns the number of hits in its current window.
    ///
    /// When `key` has no live window, a new one lasting `ttl` is opened and the
    /// returned count is 1. In the in-memory store, `ttl` only matters when a window
    /// is opened: later hits inside that window do not extend it.
    /// NOTE(review): other backends are presumably expected to follow the same
    /// fixed-window semantics — confirm.
    fn incr(&self, key: String, ttl: Duration) -> BoxFuture<'_, Result<u64>>;

    /// Returns the number of hits in `key`'s current window without recording a
    /// new hit, or 0 when `key` has no live window.
    fn count(&self, key: String) -> BoxFuture<'_, Result<u64>>;
}
24
/// Minimum number of tracked keys before expired entries are purged.
///
/// Expired entries are only removed once the map holds more than this many keys,
/// so small maps never pay for a full scan.
const SWEEP_THRESHOLD: usize = 4_096;
27
/// One fixed counting window for a single key.
struct Entry {
    /// Instant at which this window closes; once `now` reaches it the window is over.
    expires_at: Instant,
    /// Hits recorded since the window opened.
    count: u64,
}
33
34#[derive(Clone, Default)]
39pub struct MemoryThrottleStore {
40 inner: Arc<Mutex<HashMap<String, Entry>>>,
41}
42
43impl MemoryThrottleStore {
44 pub fn new() -> Self {
46 Self::default()
47 }
48}
49
50impl ThrottleStore for MemoryThrottleStore {
51 fn incr(&self, key: String, ttl: Duration) -> BoxFuture<'_, Result<u64>> {
52 Box::pin(async move {
53 let now = Instant::now();
54 let mut map = self
55 .inner
56 .lock()
57 .unwrap_or_else(|poisoned| poisoned.into_inner());
58
59 if map.len() > SWEEP_THRESHOLD {
60 map.retain(|_, entry| entry.expires_at > now);
61 }
62
63 let entry = map.entry(key).or_insert(Entry {
64 count: 0,
65 expires_at: now + ttl,
66 });
67 if entry.expires_at <= now {
69 entry.count = 0;
70 entry.expires_at = now + ttl;
71 }
72 entry.count += 1;
73 Ok(entry.count)
74 })
75 }
76
77 fn count(&self, key: String) -> BoxFuture<'_, Result<u64>> {
78 Box::pin(async move {
79 let now = Instant::now();
80 let map = self
81 .inner
82 .lock()
83 .unwrap_or_else(|poisoned| poisoned.into_inner());
84 let count = map
85 .get(&key)
86 .filter(|entry| entry.expires_at > now)
87 .map(|entry| entry.count)
88 .unwrap_or(0);
89 Ok(count)
90 })
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn counts_within_a_window_then_resets_after_it() {
        let store = MemoryThrottleStore::new();
        // Wide enough that a slow CI machine cannot let the window lapse between
        // the first three calls (which would make the count restart at 1).
        let ttl = Duration::from_millis(250);

        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 2);
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 3);

        // The window opened by the first call has certainly lapsed after sleeping
        // at least `ttl`, so counting starts over.
        tokio::time::sleep(ttl + Duration::from_millis(50)).await;
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn distinct_keys_count_independently() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_secs(60);
        assert_eq!(store.incr("a".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("b".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("a".into(), ttl).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn count_reports_the_current_window_without_incrementing() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_secs(60);

        assert_eq!(store.count("k".into()).await.unwrap(), 0);
        store.incr("k".into(), ttl).await.unwrap();
        store.incr("k".into(), ttl).await.unwrap();
        assert_eq!(store.count("k".into()).await.unwrap(), 2);
        // Reading must not record a hit.
        assert_eq!(store.count("k".into()).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn zero_ttl_window_is_already_over() {
        let store = MemoryThrottleStore::new();

        assert_eq!(store.incr("k".into(), Duration::ZERO).await.unwrap(), 1);
        assert_eq!(store.count("k".into()).await.unwrap(), 0);
        assert_eq!(store.incr("k".into(), Duration::ZERO).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn live_entries_survive_a_sweep() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_secs(60);

        // Push the map past the sweep threshold so the next `incr` runs a sweep.
        for i in 0..=SWEEP_THRESHOLD {
            store.incr(format!("key{i}"), ttl).await.unwrap();
        }
        assert_eq!(store.incr("key0".into(), ttl).await.unwrap(), 2);
    }
}