Skip to main content

tako_rs_plugins/stores/
memory.rs

1//! In-memory implementations of the [`super`] backend traits.
2//!
3//! These match the `scc::HashMap`-backed defaults that the built-in
4//! middleware shipped with before the trait split. The trait split lets users
5//! swap any of these out for Redis / Postgres / other shared backends without
6//! forking the middleware itself.
7
8use std::sync::Arc;
9use std::sync::atomic::AtomicU32;
10use std::sync::atomic::Ordering;
11use std::time::Duration;
12use std::time::Instant;
13
14use async_trait::async_trait;
15use parking_lot::Mutex;
16use scc::HashMap as SccHashMap;
17
18use super::CsrfTokenStore;
19use super::IdempotencyEntry;
20use super::IdempotencyStore;
21use super::JwksProvider;
22use super::RateLimitSnapshot;
23use super::RateLimitStore;
24use super::SessionStore;
25
26#[derive(Clone)]
27struct SessionEntry {
28  data: Vec<u8>,
29  expires_at: Instant,
30}
31
32/// In-memory session backend.
33#[derive(Default, Clone)]
34pub struct MemorySessionStore {
35  inner: Arc<SccHashMap<String, SessionEntry>>,
36}
37
38impl MemorySessionStore {
39  pub fn new() -> Self {
40    Self::default()
41  }
42}
43
44#[async_trait]
45impl SessionStore for MemorySessionStore {
46  async fn load(&self, id: &str) -> Option<Vec<u8>> {
47    let entry = self.inner.get_async(id).await?;
48    if entry.expires_at <= Instant::now() {
49      return None;
50    }
51    Some(entry.data.clone())
52  }
53
54  async fn store(&self, id: &str, data: Vec<u8>, ttl: Duration) {
55    let entry = SessionEntry {
56      data,
57      expires_at: Instant::now() + ttl,
58    };
59    let _ = self.inner.upsert_async(id.to_string(), entry).await;
60  }
61
62  async fn remove(&self, id: &str) -> bool {
63    self.inner.remove_async(id).await.is_some()
64  }
65
66  async fn sweep(&self) {
67    let now = Instant::now();
68    self.inner.retain_async(|_, v| v.expires_at > now).await;
69  }
70}
71
/// Token-bucket state for a single rate-limit key.
#[derive(Clone)]
struct Bucket {
  /// Tokens currently available; fractional so slow refill rates accumulate.
  available: f64,
  /// Burst size; refill never raises `available` above this.
  capacity: u32,
  /// Tokens credited per elapsed second.
  refill_rate_per_sec: f64,
  /// Instant up to which refill has already been credited.
  last_refill: Instant,
}

impl Bucket {
  /// Credits the tokens accrued between `last_refill` and `now`, clamped to
  /// `capacity`.
  ///
  /// `Instant::duration_since` saturates to zero when `now` is not later than
  /// `last_refill`, so an out-of-order `now` credits nothing. In that case
  /// `last_refill` is also kept where it is: moving it backwards would let the
  /// next call credit the same interval a second time.
  fn refill(&mut self, now: Instant) {
    let dt = now.duration_since(self.last_refill).as_secs_f64();
    self.available = (self.available + dt * self.refill_rate_per_sec).min(f64::from(self.capacity));
    self.last_refill = self.last_refill.max(now);
  }
}
94
95/// Token-bucket in-memory rate limiter.
96#[derive(Clone)]
97pub struct MemoryRateLimitStore {
98  capacity: u32,
99  refill_rate_per_sec: f64,
100  inner: Arc<SccHashMap<String, Arc<Mutex<Bucket>>>>,
101}
102
103impl MemoryRateLimitStore {
104  /// `capacity` is the burst size; `refill_per_sec` adds tokens continuously.
105  pub fn new(capacity: u32, refill_per_sec: f64) -> Self {
106    Self {
107      capacity,
108      refill_rate_per_sec: refill_per_sec,
109      inner: Arc::new(SccHashMap::new()),
110    }
111  }
112}
113
114#[async_trait]
115impl RateLimitStore for MemoryRateLimitStore {
116  async fn consume(&self, key: &str, cost: u32) -> Result<RateLimitSnapshot, RateLimitSnapshot> {
117    let capacity = self.capacity;
118    let refill_rate = self.refill_rate_per_sec;
119    let mutex = {
120      let entry = self
121        .inner
122        .entry_async(key.to_string())
123        .await
124        .or_insert_with(|| {
125          Arc::new(Mutex::new(Bucket {
126            available: f64::from(capacity),
127            capacity,
128            refill_rate_per_sec: refill_rate,
129            last_refill: Instant::now(),
130          }))
131        });
132      entry.get().clone()
133    };
134    let mut bucket = mutex.lock();
135    let now = Instant::now();
136    bucket.refill(now);
137    let cost_f = f64::from(cost);
138    let allowed = bucket.available >= cost_f;
139    if allowed {
140      bucket.available -= cost_f;
141    }
142    let remaining = bucket.available.max(0.0).floor() as u32;
143    let needed = (cost_f - bucket.available).max(0.0);
144    let reset_secs = if bucket.refill_rate_per_sec > 0.0 {
145      (needed / bucket.refill_rate_per_sec).ceil() as u64
146    } else {
147      0
148    };
149    let retry_after_secs = if allowed { 0 } else { reset_secs.max(1) };
150    let snap = RateLimitSnapshot {
151      limit: bucket.capacity,
152      remaining,
153      reset_secs,
154      retry_after_secs,
155    };
156    if allowed { Ok(snap) } else { Err(snap) }
157  }
158}
159
160#[derive(Clone)]
161struct StoredIdempotency {
162  entry: IdempotencyEntry,
163  expires_at: Instant,
164}
165
166/// In-memory idempotency cache.
167#[derive(Clone)]
168pub struct MemoryIdempotencyStore {
169  inner: Arc<SccHashMap<String, StoredIdempotency>>,
170  inflight_ttl: Duration,
171}
172
173impl Default for MemoryIdempotencyStore {
174  fn default() -> Self {
175    Self {
176      inner: Arc::new(SccHashMap::new()),
177      // Long enough to outlive a typical synchronous handler but short
178      // enough to drop crashed in-flight entries before they leak. Override
179      // with `with_inflight_ttl` for slow uploads / long-running handlers.
180      inflight_ttl: Duration::from_secs(300),
181    }
182  }
183}
184
185impl MemoryIdempotencyStore {
186  pub fn new() -> Self {
187    Self::default()
188  }
189
190  /// Override the TTL used for in-flight `begin` entries. Default 300s.
191  /// Set this to be at least as long as the slowest handler that may
192  /// register an idempotency key — anything shorter risks an in-flight
193  /// entry expiring before the handler completes, breaking coalescing.
194  pub fn with_inflight_ttl(mut self, ttl: Duration) -> Self {
195    self.inflight_ttl = ttl;
196    self
197  }
198}
199
200#[async_trait]
201impl IdempotencyStore for MemoryIdempotencyStore {
202  async fn get(&self, key: &str) -> Option<IdempotencyEntry> {
203    let stored = self.inner.get_async(key).await?;
204    if stored.expires_at <= Instant::now() {
205      return None;
206    }
207    Some(stored.entry.clone())
208  }
209
210  async fn begin(&self, key: &str, payload_sig: [u8; 20]) -> IdempotencyEntry {
211    let entry = IdempotencyEntry {
212      status: 0,
213      headers: Vec::new(),
214      body: Vec::new(),
215      payload_sig,
216      completed: false,
217    };
218    let stored = StoredIdempotency {
219      entry: entry.clone(),
220      expires_at: Instant::now() + self.inflight_ttl,
221    };
222    let _ = self.inner.upsert_async(key.to_string(), stored).await;
223    entry
224  }
225
226  async fn complete(&self, key: &str, entry: IdempotencyEntry, ttl: Duration) {
227    let stored = StoredIdempotency {
228      entry,
229      expires_at: Instant::now() + ttl,
230    };
231    let _ = self.inner.upsert_async(key.to_string(), stored).await;
232  }
233
234  async fn remove(&self, key: &str) {
235    let _ = self.inner.remove_async(key).await;
236  }
237}
238
239/// Static-snapshot JWKS provider.
240#[derive(Default, Clone)]
241pub struct StaticJwksProvider {
242  by_kid: Arc<SccHashMap<String, Vec<Vec<u8>>>>,
243}
244
245impl StaticJwksProvider {
246  pub fn new() -> Self {
247    Self::default()
248  }
249
250  /// Adds a key under `kid`. Multiple keys per kid are supported (rotation).
251  ///
252  /// PPL-25: the previous `update_sync` + `insert_sync` fallback was racy.
253  /// Two threads both observing `update_sync == None` would each call
254  /// `insert_sync`; only the first wins, the loser's key was silently
255  /// dropped — a rotation footgun where the new key would not be findable
256  /// at verify time. `entry_sync(...).and_modify(...).or_insert_with(...)`
257  /// performs the lookup-or-insert atomically under the scc bucket lock.
258  pub fn insert(&self, kid: impl Into<String>, key: Vec<u8>) {
259    let kid = kid.into();
260    self
261      .by_kid
262      .entry_sync(kid)
263      .and_modify(|v| v.push(key.clone()))
264      .or_insert_with(|| vec![key]);
265  }
266}
267
268#[async_trait]
269impl JwksProvider for StaticJwksProvider {
270  async fn keys_for(&self, kid: &str) -> Vec<Vec<u8>> {
271    self
272      .by_kid
273      .get_async(kid)
274      .await
275      .map(|v| v.clone())
276      .unwrap_or_default()
277  }
278}
279
280#[derive(Clone)]
281struct CsrfRecord {
282  token: String,
283  expires_at: Instant,
284  uses_left: Arc<AtomicU32>,
285}
286
287/// In-memory CSRF token store.
288#[derive(Default, Clone)]
289pub struct MemoryCsrfTokenStore {
290  inner: Arc<SccHashMap<String, CsrfRecord>>,
291}
292
293impl MemoryCsrfTokenStore {
294  pub fn new() -> Self {
295    Self::default()
296  }
297}
298
299#[async_trait]
300impl CsrfTokenStore for MemoryCsrfTokenStore {
301  async fn issue(&self, session_id: &str, ttl: Duration) -> String {
302    let token = uuid::Uuid::new_v4().simple().to_string();
303    let record = CsrfRecord {
304      token: token.clone(),
305      expires_at: Instant::now() + ttl,
306      uses_left: Arc::new(AtomicU32::new(u32::MAX)),
307    };
308    let _ = self
309      .inner
310      .upsert_async(session_id.to_string(), record)
311      .await;
312    token
313  }
314
315  async fn validate(&self, session_id: &str, candidate: &str, single_use: bool) -> bool {
316    let record = self.inner.get_async(session_id).await;
317    let Some(record) = record else {
318      return false;
319    };
320    if record.expires_at <= Instant::now() {
321      return false;
322    }
323    if record.token != candidate {
324      return false;
325    }
326    if single_use {
327      // CAS decrement: `fetch_sub(1)` on a zero counter underflows to
328      // `u32::MAX`, which would silently re-arm the credential. Loop with
329      // `compare_exchange` so we only consume an actually-positive count.
330      loop {
331        let cur = record.uses_left.load(Ordering::Acquire);
332        if cur == 0 {
333          return false;
334        }
335        if record
336          .uses_left
337          .compare_exchange(cur, cur - 1, Ordering::AcqRel, Ordering::Acquire)
338          .is_ok()
339        {
340          break;
341        }
342      }
343    }
344    true
345  }
346}