1use std::sync::Arc;
9use std::sync::atomic::AtomicU32;
10use std::sync::atomic::Ordering;
11use std::time::Duration;
12use std::time::Instant;
13
14use async_trait::async_trait;
15use parking_lot::Mutex;
16use scc::HashMap as SccHashMap;
17
18use super::CsrfTokenStore;
19use super::IdempotencyEntry;
20use super::IdempotencyStore;
21use super::JwksProvider;
22use super::RateLimitSnapshot;
23use super::RateLimitStore;
24use super::SessionStore;
25
/// A stored session payload paired with the instant at which it stops being
/// served.
#[derive(Clone)]
struct SessionEntry {
    /// Lookups at or after this instant treat the entry as absent.
    expires_at: Instant,
    /// Opaque session bytes exactly as they were handed to `store`.
    data: Vec<u8>,
}
31
32#[derive(Default, Clone)]
34pub struct MemorySessionStore {
35 inner: Arc<SccHashMap<String, SessionEntry>>,
36}
37
38impl MemorySessionStore {
39 pub fn new() -> Self {
40 Self::default()
41 }
42}
43
44#[async_trait]
45impl SessionStore for MemorySessionStore {
46 async fn load(&self, id: &str) -> Option<Vec<u8>> {
47 let entry = self.inner.get_async(id).await?;
48 if entry.expires_at <= Instant::now() {
49 return None;
50 }
51 Some(entry.data.clone())
52 }
53
54 async fn store(&self, id: &str, data: Vec<u8>, ttl: Duration) {
55 let entry = SessionEntry {
56 data,
57 expires_at: Instant::now() + ttl,
58 };
59 let _ = self.inner.upsert_async(id.to_string(), entry).await;
60 }
61
62 async fn remove(&self, id: &str) -> bool {
63 self.inner.remove_async(id).await.is_some()
64 }
65
66 async fn sweep(&self) {
67 let now = Instant::now();
68 self.inner.retain_async(|_, v| v.expires_at > now).await;
69 }
70}
71
/// Token-bucket state for a single rate-limit key.
#[derive(Clone)]
struct Bucket {
    /// Tokens currently available; fractional so partial refills accumulate.
    available: f64,
    /// Upper bound that `available` is clamped to on refill.
    capacity: u32,
    /// Tokens credited per elapsed second.
    refill_rate_per_sec: f64,
    /// Instant up to which elapsed time has already been credited.
    last_refill: Instant,
}

impl Bucket {
    /// Credits tokens for the time elapsed since `last_refill`, capped at
    /// `capacity`, and advances `last_refill` to `now`.
    ///
    /// If `now` is earlier than `last_refill`, nothing is credited and
    /// `last_refill` is left unchanged. This keeps an interval from being
    /// credited twice on a later call.
    fn refill(&mut self, now: Instant) {
        // Yields zero rather than a negative span, so no sign check is needed.
        let elapsed = now
            .saturating_duration_since(self.last_refill)
            .as_secs_f64();
        let refilled = self.available + elapsed * self.refill_rate_per_sec;
        self.available = refilled.min(f64::from(self.capacity));
        // Never move the watermark backwards.
        self.last_refill = self.last_refill.max(now);
    }
}
94
95#[derive(Clone)]
97pub struct MemoryRateLimitStore {
98 capacity: u32,
99 refill_rate_per_sec: f64,
100 inner: Arc<SccHashMap<String, Arc<Mutex<Bucket>>>>,
101}
102
103impl MemoryRateLimitStore {
104 pub fn new(capacity: u32, refill_per_sec: f64) -> Self {
106 Self {
107 capacity,
108 refill_rate_per_sec: refill_per_sec,
109 inner: Arc::new(SccHashMap::new()),
110 }
111 }
112}
113
114#[async_trait]
115impl RateLimitStore for MemoryRateLimitStore {
116 async fn consume(&self, key: &str, cost: u32) -> Result<RateLimitSnapshot, RateLimitSnapshot> {
117 let capacity = self.capacity;
118 let refill_rate = self.refill_rate_per_sec;
119 let mutex = {
120 let entry = self
121 .inner
122 .entry_async(key.to_string())
123 .await
124 .or_insert_with(|| {
125 Arc::new(Mutex::new(Bucket {
126 available: f64::from(capacity),
127 capacity,
128 refill_rate_per_sec: refill_rate,
129 last_refill: Instant::now(),
130 }))
131 });
132 entry.get().clone()
133 };
134 let mut bucket = mutex.lock();
135 let now = Instant::now();
136 bucket.refill(now);
137 let cost_f = f64::from(cost);
138 let allowed = bucket.available >= cost_f;
139 if allowed {
140 bucket.available -= cost_f;
141 }
142 let remaining = bucket.available.max(0.0).floor() as u32;
143 let needed = (cost_f - bucket.available).max(0.0);
144 let reset_secs = if bucket.refill_rate_per_sec > 0.0 {
145 (needed / bucket.refill_rate_per_sec).ceil() as u64
146 } else {
147 0
148 };
149 let retry_after_secs = if allowed { 0 } else { reset_secs.max(1) };
150 let snap = RateLimitSnapshot {
151 limit: bucket.capacity,
152 remaining,
153 reset_secs,
154 retry_after_secs,
155 };
156 if allowed { Ok(snap) } else { Err(snap) }
157 }
158}
159
160#[derive(Clone)]
161struct StoredIdempotency {
162 entry: IdempotencyEntry,
163 expires_at: Instant,
164}
165
166#[derive(Clone)]
168pub struct MemoryIdempotencyStore {
169 inner: Arc<SccHashMap<String, StoredIdempotency>>,
170 inflight_ttl: Duration,
171}
172
173impl Default for MemoryIdempotencyStore {
174 fn default() -> Self {
175 Self {
176 inner: Arc::new(SccHashMap::new()),
177 inflight_ttl: Duration::from_secs(300),
181 }
182 }
183}
184
185impl MemoryIdempotencyStore {
186 pub fn new() -> Self {
187 Self::default()
188 }
189
190 pub fn with_inflight_ttl(mut self, ttl: Duration) -> Self {
195 self.inflight_ttl = ttl;
196 self
197 }
198}
199
200#[async_trait]
201impl IdempotencyStore for MemoryIdempotencyStore {
202 async fn get(&self, key: &str) -> Option<IdempotencyEntry> {
203 let stored = self.inner.get_async(key).await?;
204 if stored.expires_at <= Instant::now() {
205 return None;
206 }
207 Some(stored.entry.clone())
208 }
209
210 async fn begin(&self, key: &str, payload_sig: [u8; 20]) -> IdempotencyEntry {
211 let entry = IdempotencyEntry {
212 status: 0,
213 headers: Vec::new(),
214 body: Vec::new(),
215 payload_sig,
216 completed: false,
217 };
218 let stored = StoredIdempotency {
219 entry: entry.clone(),
220 expires_at: Instant::now() + self.inflight_ttl,
221 };
222 let _ = self.inner.upsert_async(key.to_string(), stored).await;
223 entry
224 }
225
226 async fn complete(&self, key: &str, entry: IdempotencyEntry, ttl: Duration) {
227 let stored = StoredIdempotency {
228 entry,
229 expires_at: Instant::now() + ttl,
230 };
231 let _ = self.inner.upsert_async(key.to_string(), stored).await;
232 }
233
234 async fn remove(&self, key: &str) {
235 let _ = self.inner.remove_async(key).await;
236 }
237}
238
239#[derive(Default, Clone)]
241pub struct StaticJwksProvider {
242 by_kid: Arc<SccHashMap<String, Vec<Vec<u8>>>>,
243}
244
245impl StaticJwksProvider {
246 pub fn new() -> Self {
247 Self::default()
248 }
249
250 pub fn insert(&self, kid: impl Into<String>, key: Vec<u8>) {
259 let kid = kid.into();
260 self
261 .by_kid
262 .entry_sync(kid)
263 .and_modify(|v| v.push(key.clone()))
264 .or_insert_with(|| vec![key]);
265 }
266}
267
268#[async_trait]
269impl JwksProvider for StaticJwksProvider {
270 async fn keys_for(&self, kid: &str) -> Vec<Vec<u8>> {
271 self
272 .by_kid
273 .get_async(kid)
274 .await
275 .map(|v| v.clone())
276 .unwrap_or_default()
277 }
278}
279
/// A CSRF token bound to one session.
#[derive(Clone)]
struct CsrfRecord {
    /// Remaining permitted uses; zero means the token can no longer be used.
    /// Kept behind an `Arc` so every clone of the record shares one counter.
    uses_left: Arc<AtomicU32>,
    /// Instant at which the token stops validating.
    expires_at: Instant,
    /// The token string handed to the client.
    token: String,
}
286
287#[derive(Default, Clone)]
289pub struct MemoryCsrfTokenStore {
290 inner: Arc<SccHashMap<String, CsrfRecord>>,
291}
292
293impl MemoryCsrfTokenStore {
294 pub fn new() -> Self {
295 Self::default()
296 }
297}
298
299#[async_trait]
300impl CsrfTokenStore for MemoryCsrfTokenStore {
301 async fn issue(&self, session_id: &str, ttl: Duration) -> String {
302 let token = uuid::Uuid::new_v4().simple().to_string();
303 let record = CsrfRecord {
304 token: token.clone(),
305 expires_at: Instant::now() + ttl,
306 uses_left: Arc::new(AtomicU32::new(u32::MAX)),
307 };
308 let _ = self
309 .inner
310 .upsert_async(session_id.to_string(), record)
311 .await;
312 token
313 }
314
315 async fn validate(&self, session_id: &str, candidate: &str, single_use: bool) -> bool {
316 let record = self.inner.get_async(session_id).await;
317 let Some(record) = record else {
318 return false;
319 };
320 if record.expires_at <= Instant::now() {
321 return false;
322 }
323 if record.token != candidate {
324 return false;
325 }
326 if single_use {
327 loop {
331 let cur = record.uses_left.load(Ordering::Acquire);
332 if cur == 0 {
333 return false;
334 }
335 if record
336 .uses_left
337 .compare_exchange(cur, cur - 1, Ordering::AcqRel, Ordering::Acquire)
338 .is_ok()
339 {
340 break;
341 }
342 }
343 }
344 true
345 }
346}