Skip to main content

trillium_cache/
memory.rs

1//! In-memory [`CacheStorage`].
2//!
3//! [`InMemoryStorage`] is suitable for production reverse-proxy and
4//! client-side caching: byte-aware size cap, scan-resistant
5//! admission, and concurrent reads and writes on distinct keys
6//! without contention.
7//!
8//! ## Granularity
9//!
10//! Eviction is coarse: the unit is one [`CacheKey`] (method + URL),
11//! and all `Vary` variants stored under that key live and die together
12//! during eviction. In typical traffic patterns variants of the same URL
13//! are hot or cold together (a single `Accept-Encoding` is usually
14//! dominant, etc.), so the cost is bounded — at worst we keep a few
15//! cold variants resident alongside one hot variant. This is correct
16//! per RFC 9111; the only consequence is slightly less efficient use
17//! of memory than per-variant eviction would give.
18//!
19//! ## Sizing
20//!
21//! The byte cap is enforced over stored *body* bytes only (the
22//! dominant cost); headers and other metadata are not counted. The
23//! per-response cap on [`Cache::with_max_cacheable_size`] interacts
24//! independently — that one bounds how large any single response may
25//! be; the storage cap bounds total resident size across the cache.
26//!
27//! [`Cache::with_max_cacheable_size`]: crate::Cache::with_max_cacheable_size
28
29use crate::{CacheKey, CachePolicy, CacheStorage, PutHandle, StoredEntry};
30use futures_lite::{AsyncRead, AsyncWrite};
31use moka::{future::Cache, ops::compute::Op};
32use std::{
33    fmt::{self, Debug, Formatter},
34    io,
35    pin::Pin,
36    sync::Arc,
37    task::{Context, Poll},
38    time::Duration,
39};
40use trillium_http::{Body, BodySource, Headers};
41
42const DEFAULT_MAX_CAPACITY_BYTES: u64 = 256 * 1024 * 1024;
43
44// All variants stored under one CacheKey. Cheap to clone (Arc); held as the moka
45// value type so eviction operates per-CacheKey.
46type Bucket = Arc<[Variant]>;
47
48#[derive(Clone)]
49struct Variant {
50    policy: Arc<CachePolicy>,
51    body: Arc<[u8]>,
52    trailers: Option<Headers>,
53}
54
55impl Debug for Variant {
56    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
57        f.debug_struct("Variant")
58            .field("body_len", &self.body.len())
59            .field("has_trailers", &self.trailers.is_some())
60            .finish_non_exhaustive()
61    }
62}
63
64/// Bounded in-memory cache storage.
65///
66/// Defaults to a 256 MiB byte cap; override with
67/// [`with_max_capacity_bytes`][Self::with_max_capacity_bytes],
68/// [`unbounded`][Self::unbounded],
69/// [`with_time_to_idle`][Self::with_time_to_idle], and
70/// [`with_time_to_live`][Self::with_time_to_live]. Each setter
71/// discards any previously inserted entries; configure at
72/// construction, before the storage is populated or shared.
73///
74/// `Clone` is cheap — clones share the same backing storage.
75#[derive(Clone)]
76pub struct InMemoryStorage {
77    cache: Cache<CacheKey, Bucket>,
78    max_capacity_bytes: Option<u64>,
79    time_to_idle: Option<Duration>,
80    time_to_live: Option<Duration>,
81}
82
83impl Debug for InMemoryStorage {
84    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
85        f.debug_struct("InMemoryStorage")
86            .field("entry_count", &self.cache.entry_count())
87            .field("weighted_size", &self.cache.weighted_size())
88            .field("max_capacity_bytes", &self.max_capacity_bytes)
89            .field("time_to_idle", &self.time_to_idle)
90            .field("time_to_live", &self.time_to_live)
91            .finish_non_exhaustive()
92    }
93}
94
95impl Default for InMemoryStorage {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101impl InMemoryStorage {
102    /// Construct an in-memory storage with default settings: a
103    /// 256 MiB byte cap, no idle eviction, no TTL.
104    pub fn new() -> Self {
105        Self {
106            cache: build_cache(Some(DEFAULT_MAX_CAPACITY_BYTES), None, None),
107            max_capacity_bytes: Some(DEFAULT_MAX_CAPACITY_BYTES),
108            time_to_idle: None,
109            time_to_live: None,
110        }
111    }
112
113    /// Set the maximum total stored body size, in bytes. Entries are
114    /// evicted when inserts would exceed this cap. Defaults to
115    /// 256 MiB.
116    pub fn with_max_capacity_bytes(mut self, bytes: u64) -> Self {
117        self.max_capacity_bytes = Some(bytes);
118        self.rebuild();
119        self
120    }
121
122    /// Remove the size cap. The cache grows without bound. Useful in
123    /// tests and short-lived processes; production deployments should
124    /// prefer the default capped configuration.
125    pub fn unbounded(mut self) -> Self {
126        self.max_capacity_bytes = None;
127        self.rebuild();
128        self
129    }
130
131    /// Evict entries that have not been read in this duration. Off by
132    /// default.
133    pub fn with_time_to_idle(mut self, duration: Duration) -> Self {
134        self.time_to_idle = Some(duration);
135        self.rebuild();
136        self
137    }
138
139    /// Evict entries this duration after their last insert,
140    /// regardless of access. Off by default.
141    ///
142    /// Note: this is independent of RFC 9111 freshness — a stored
143    /// entry may be evicted by TTL while still within its
144    /// `max-age`/`s-maxage` window, or remain past it (the
145    /// [`CachePolicy`] handles freshness on read).
146    pub fn with_time_to_live(mut self, duration: Duration) -> Self {
147        self.time_to_live = Some(duration);
148        self.rebuild();
149        self
150    }
151
152    /// Approximate count of stored [`CacheKey`]s. Each key may hold
153    /// multiple `Vary` variants. Eventually consistent — call
154    /// [`run_pending_tasks`][Self::run_pending_tasks] first for a
155    /// settled value (useful in tests).
156    pub fn entry_count(&self) -> u64 {
157        self.cache.entry_count()
158    }
159
160    /// Approximate total weighted size (sum of stored body bytes
161    /// across all entries). Eventually consistent — call
162    /// [`run_pending_tasks`][Self::run_pending_tasks] first for a
163    /// settled value.
164    pub fn weighted_size(&self) -> u64 {
165        self.cache.weighted_size()
166    }
167
168    /// Flush pending eviction/insertion bookkeeping. Call before
169    /// reading [`entry_count`][Self::entry_count] or
170    /// [`weighted_size`][Self::weighted_size] when an exact value
171    /// matters.
172    pub async fn run_pending_tasks(&self) {
173        self.cache.run_pending_tasks().await;
174    }
175
176    // moka::future::Cache has no resize/set-capacity API — configuration is
177    // fixed at build time. Each setter rebuilds the backing cache.
178    fn rebuild(&mut self) {
179        self.cache = build_cache(
180            self.max_capacity_bytes,
181            self.time_to_idle,
182            self.time_to_live,
183        );
184    }
185}
186
187fn build_cache(
188    max_capacity_bytes: Option<u64>,
189    time_to_idle: Option<Duration>,
190    time_to_live: Option<Duration>,
191) -> Cache<CacheKey, Bucket> {
192    let mut builder = Cache::<CacheKey, Bucket>::builder().weigher(weigh_bucket);
193    if let Some(cap) = max_capacity_bytes {
194        builder = builder.max_capacity(cap);
195    }
196    if let Some(tti) = time_to_idle {
197        builder = builder.time_to_idle(tti);
198    }
199    if let Some(ttl) = time_to_live {
200        builder = builder.time_to_live(ttl);
201    }
202    builder.build()
203}
204
205fn weigh_bucket(_key: &CacheKey, bucket: &Bucket) -> u32 {
206    let total: u64 = bucket.iter().map(|v| v.body.len() as u64).sum();
207    u32::try_from(total).unwrap_or(u32::MAX)
208}
209
210/// In-memory [`StoredEntry`]. Cheap to clone — fields are `Arc`-shared
211/// with the backing cache.
212#[derive(Clone)]
213pub struct InMemoryEntry {
214    variant: Variant,
215    cache: Cache<CacheKey, Bucket>,
216    key: CacheKey,
217}
218
219impl Debug for InMemoryEntry {
220    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
221        f.debug_struct("InMemoryEntry")
222            .field("key", &self.key)
223            .field("variant", &self.variant)
224            .finish_non_exhaustive()
225    }
226}
227
228impl StoredEntry for InMemoryEntry {
229    fn policy(&self) -> &CachePolicy {
230        &self.variant.policy
231    }
232
233    async fn refresh_policy(&mut self, new_policy: CachePolicy) -> io::Result<()> {
234        let new_arc = Arc::new(new_policy);
235        // Update the local view first so an immediately-following policy() call sees the new
236        // value even if the cache update below has nothing to write back (e.g. the entry was
237        // already evicted).
238        self.variant.policy = Arc::clone(&new_arc);
239
240        self.cache
241            .entry(self.key.clone())
242            .and_compute_with(|maybe_entry| async move {
243                let Some(entry) = maybe_entry else {
244                    return Op::Nop;
245                };
246                let bucket = entry.into_value();
247                let mut updated = false;
248                let new_variants: Vec<Variant> = bucket
249                    .iter()
250                    .map(|v| {
251                        if !updated && v.policy.same_variant_as(&new_arc) {
252                            updated = true;
253                            Variant {
254                                policy: Arc::clone(&new_arc),
255                                body: Arc::clone(&v.body),
256                                trailers: v.trailers.clone(),
257                            }
258                        } else {
259                            v.clone()
260                        }
261                    })
262                    .collect();
263                if updated {
264                    Op::Put(Arc::from(new_variants.into_boxed_slice()))
265                } else {
266                    Op::Nop
267                }
268            })
269            .await;
270        Ok(())
271    }
272
273    async fn open(self) -> io::Result<Body> {
274        let Variant { body, trailers, .. } = self.variant;
275        let len = u64::try_from(body.len()).ok();
276        let source = ReplayBodySource {
277            body,
278            position: 0,
279            trailers,
280        };
281        Ok(Body::new_with_trailers(source, len))
282    }
283}
284
285// BodySource over a shared Arc<[u8]>. No copy on open; reads slice through the Arc.
286struct ReplayBodySource {
287    body: Arc<[u8]>,
288    position: usize,
289    trailers: Option<Headers>,
290}
291
292impl AsyncRead for ReplayBodySource {
293    fn poll_read(
294        mut self: Pin<&mut Self>,
295        _cx: &mut Context<'_>,
296        buf: &mut [u8],
297    ) -> Poll<io::Result<usize>> {
298        let remaining = self.body.len() - self.position;
299        let n = remaining.min(buf.len());
300        if n > 0 {
301            buf[..n].copy_from_slice(&self.body[self.position..self.position + n]);
302            self.position += n;
303        }
304        Poll::Ready(Ok(n))
305    }
306}
307
308impl BodySource for ReplayBodySource {
309    fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
310        self.get_mut().trailers.take()
311    }
312}
313
314impl CacheStorage for InMemoryStorage {
315    type PutHandle = InMemoryPutHandle;
316    type StoredEntry = InMemoryEntry;
317
318    async fn get(&self, key: &CacheKey) -> Vec<Self::StoredEntry> {
319        let Some(bucket) = self.cache.get(key).await else {
320            return Vec::new();
321        };
322        bucket
323            .iter()
324            .map(|variant| InMemoryEntry {
325                variant: variant.clone(),
326                cache: self.cache.clone(),
327                key: key.clone(),
328            })
329            .collect()
330    }
331
332    async fn put(&self, key: CacheKey, policy: CachePolicy) -> io::Result<Self::PutHandle> {
333        Ok(InMemoryPutHandle {
334            cache: self.cache.clone(),
335            key,
336            policy,
337            buffer: Vec::new(),
338        })
339    }
340
341    async fn invalidate(&self, key: &CacheKey) {
342        self.cache.invalidate(key).await;
343    }
344}
345
346/// Streaming [`PutHandle`] for [`InMemoryStorage`].
347///
348/// Buffers writes internally; [`finalize`][Self::finalize] commits the
349/// buffered bytes and any trailers to the cache atomically. Drop
350/// without finalize discards the buffered bytes.
351#[derive(Debug)]
352pub struct InMemoryPutHandle {
353    cache: Cache<CacheKey, Bucket>,
354    key: CacheKey,
355    policy: CachePolicy,
356    buffer: Vec<u8>,
357}
358
359impl AsyncWrite for InMemoryPutHandle {
360    fn poll_write(
361        mut self: Pin<&mut Self>,
362        _cx: &mut Context<'_>,
363        buf: &[u8],
364    ) -> Poll<io::Result<usize>> {
365        self.buffer.extend_from_slice(buf);
366        Poll::Ready(Ok(buf.len()))
367    }
368
369    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
370        Poll::Ready(Ok(()))
371    }
372
373    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
374        Poll::Ready(Ok(()))
375    }
376}
377
378impl PutHandle for InMemoryPutHandle {
379    async fn finalize(self, trailers: Option<Headers>) -> io::Result<()> {
380        let Self {
381            cache,
382            key,
383            policy,
384            buffer,
385        } = self;
386        let new_variant = Variant {
387            policy: Arc::new(policy),
388            body: Arc::from(buffer.into_boxed_slice()),
389            trailers,
390        };
391
392        cache
393            .entry(key)
394            .and_upsert_with(|maybe_entry| async move {
395                let mut variants: Vec<Variant> = match maybe_entry {
396                    Some(entry) => entry.into_value().to_vec(),
397                    None => Vec::new(),
398                };
399                variants.retain(|v| !v.policy.same_variant_as(&new_variant.policy));
400                variants.push(new_variant);
401                Arc::from(variants.into_boxed_slice())
402            })
403            .await;
404        Ok(())
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use futures_lite::{AsyncReadExt, AsyncWriteExt};
    use std::time::SystemTime;
    use trillium_client::Conn;
    use trillium_http::{KnownHeaderName::*, Method, Status};
    use trillium_testing::{TestResult, harness, test};

    /// GET cache key for an arbitrary URL.
    fn key_for(url: &str) -> CacheKey {
        CacheKey::new(Method::Get, url.parse().unwrap())
    }

    /// The key most tests store under.
    fn key() -> CacheKey {
        key_for("http://example.com/")
    }

    /// Store `body` under `at`, with a policy derived from `conn` at the current time.
    async fn store_at(storage: &InMemoryStorage, at: CacheKey, conn: &Conn, body: &[u8]) {
        let policy = policy_from(conn, SystemTime::now(), private_cache());
        let mut writer = storage.put(at, policy).await.unwrap();
        writer.write_all(body).await.unwrap();
        writer.finalize(None).await.unwrap();
    }

    /// Store `body` under the shared test key.
    async fn store(storage: &InMemoryStorage, conn: &Conn, body: &[u8]) {
        store_at(storage, key(), conn, body).await;
    }

    /// Open an entry and read its whole body.
    async fn read_body(entry: InMemoryEntry) -> Vec<u8> {
        let mut reader = entry.open().await.unwrap();
        let mut collected = Vec::new();
        reader.read_to_end(&mut collected).await.unwrap();
        collected
    }

    #[test(harness)]
    async fn get_missing_key_returns_empty() -> TestResult {
        let storage = InMemoryStorage::new();
        let found = storage.get(&key()).await;
        assert!(found.is_empty());
        Ok(())
    }

    #[test(harness)]
    async fn put_then_get_returns_entry() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        store(&storage, &conn, b"hello").await;

        let found = storage.get(&key()).await;
        assert_eq!(found.len(), 1);
        assert_eq!(read_body(found[0].clone()).await, b"hello");
        Ok(())
    }

    #[test(harness)]
    async fn put_with_same_vary_replaces() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[(AcceptEncoding, "gzip")],
            Status::Ok,
            &[(CacheControl, "max-age=600"), (Vary, "Accept-Encoding")],
        );
        for body in [b"v1", b"v2"] {
            store(&storage, &conn, body).await;
        }

        let found = storage.get(&key()).await;
        assert_eq!(found.len(), 1);
        assert_eq!(read_body(found[0].clone()).await, b"v2");
        Ok(())
    }

    #[test(harness)]
    async fn put_with_different_vary_appends() -> TestResult {
        let storage = InMemoryStorage::new();
        let gzip = exchange(
            Method::Get,
            &[(AcceptEncoding, "gzip")],
            Status::Ok,
            &[(CacheControl, "max-age=600"), (Vary, "Accept-Encoding")],
        );
        let br = exchange(
            Method::Get,
            &[(AcceptEncoding, "br")],
            Status::Ok,
            &[(CacheControl, "max-age=600"), (Vary, "Accept-Encoding")],
        );
        store(&storage, &gzip, b"gz").await;
        store(&storage, &br, b"br").await;

        let found = storage.get(&key()).await;
        assert_eq!(found.len(), 2);
        Ok(())
    }

    #[test(harness)]
    async fn invalidate_removes_all_entries_for_key() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        store(&storage, &conn, b"x").await;
        storage.run_pending_tasks().await;
        assert_eq!(storage.entry_count(), 1);

        storage.invalidate(&key()).await;
        assert!(storage.get(&key()).await.is_empty());
        storage.run_pending_tasks().await;
        assert_eq!(storage.entry_count(), 0);
        Ok(())
    }

    #[test(harness)]
    async fn invalidate_does_not_touch_other_keys() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        let key_a = key_for("http://a.example/");
        let key_b = key_for("http://b.example/");
        store_at(&storage, key_a.clone(), &conn, b"a").await;
        store_at(&storage, key_b.clone(), &conn, b"b").await;

        storage.invalidate(&key_a).await;
        assert!(storage.get(&key_a).await.is_empty());
        assert_eq!(storage.get(&key_b).await.len(), 1);
        Ok(())
    }

    #[test(harness)]
    async fn drop_put_handle_without_finalize_discards() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        let policy = policy_from(&conn, SystemTime::now(), private_cache());
        let mut abandoned = storage.put(key(), policy).await.unwrap();
        abandoned.write_all(b"partial").await.unwrap();
        drop(abandoned);

        assert!(storage.get(&key()).await.is_empty());
        Ok(())
    }

    #[test(harness)]
    async fn refresh_policy_updates_storage() -> TestResult {
        let storage = InMemoryStorage::new();
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        store(&storage, &conn, b"body").await;

        let mut stored = storage.get(&key()).await;
        let original_time = stored[0].policy().response_time;
        let refreshed = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=1200")],
        );
        let later = original_time + Duration::from_secs(100);
        let new_policy = policy_from(&refreshed, later, private_cache());
        stored[0].refresh_policy(new_policy).await.unwrap();

        let reread = storage.get(&key()).await;
        assert_eq!(reread.len(), 1);
        assert_ne!(reread[0].policy().response_time, original_time);
        Ok(())
    }

    // Size-bounded: push more bytes in than the cap allows and check the cache stays within it.
    #[test(harness)]
    async fn size_cap_evicts_old_entries() -> TestResult {
        // 1 KiB cap; ten 600-byte responses, each under its own URL.
        let storage = InMemoryStorage::new().with_max_capacity_bytes(1024);
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(CacheControl, "max-age=600")],
        );
        let body = vec![b'x'; 600];
        for i in 0..10 {
            let url = format!("http://example.com/{i}");
            store_at(&storage, key_for(&url), &conn, &body).await;
        }
        storage.run_pending_tasks().await;
        assert!(
            storage.weighted_size() <= 1024,
            "weighted size {} should be within cap of 1024",
            storage.weighted_size()
        );
        Ok(())
    }
}