Skip to main content

trillium_cache/
server.rs

//! Server-side cache handler.
//!
//! [`Cache`] wires [`CacheStorage`] + [`CachePolicy`] onto a `trillium` server's handler chain.
//!
//! ## Position in the handler chain
//!
//! Place `Cache` *before* the handler whose responses you want to cache:
//!
//! ```ignore
//! let app = (
//!     Logger::new(),
//!     trillium_cache::Cache::new(InMemoryStorage::new()),
//!     my_app_handler,
//! );
//! ```
//!
//! ## Streaming
//!
//! Cacheable responses stream through the cache: as bytes arrive from the downstream handler,
//! they are written to storage *and* forwarded to the user concurrently. Trailers from the
//! response body propagate to both sides. Dropping the response body before EOF aborts the
//! cache write — the partial bytes are discarded.
//!
//! ## Stale-while-revalidate
//!
//! This handler does not implement background `stale-while-revalidate`. A stale entry within
//! its `stale-while-revalidate` window falls through to synchronous revalidation (the inner
//! handler runs while the request is in flight). `stale-if-error` recovery *is* supported:
//! when the downstream handler produces a 5xx and the stored entry is SIE-eligible, the
//! cache serves the stored entry instead.
//!
//! For background `stale-while-revalidate`, use the client-side handler (gated on the
//! `client` feature) instead: install it on the client that
//! [`trillium-proxy`](https://docs.rs/trillium-proxy) uses to reach its upstream.
35
36use crate::{
37    CacheKey, CacheOptions, CachePolicy, CacheStorage, StoredEntry,
38    tee::TeeingReader,
39    validation::{AfterResponse, BeforeRequest},
40};
41use std::{sync::Arc, time::SystemTime};
42use trillium::{Body, Conn, Handler, KnownHeaderName, Method};
43use url::Url;
44
/// Body-size cap, in bytes, applied by [`Cache::new`] when none is configured: 16 MiB.
const DEFAULT_MAX_CACHEABLE_SIZE: u64 = 16 << 20;
46
47/// Server-side cache handler. Mount on a trillium handler chain together with a
48/// [`CacheStorage`] backend.
49#[derive(Debug)]
50pub struct Cache<S: CacheStorage> {
51    storage: Arc<S>,
52    options: CacheOptions,
53    max_cacheable_size: u64,
54}
55
56impl<S: CacheStorage> Clone for Cache<S> {
57    fn clone(&self) -> Self {
58        Self {
59            storage: Arc::clone(&self.storage),
60            options: self.options,
61            max_cacheable_size: self.max_cacheable_size,
62        }
63    }
64}
65
66impl<S: CacheStorage> Cache<S> {
67    /// Construct a cache handler with default options
68    /// ([`CacheOptions::default`]) and a 16 MiB body-size cap.
69    pub fn new(storage: S) -> Self {
70        Self {
71            storage: Arc::new(storage),
72            options: CacheOptions::default(),
73            max_cacheable_size: DEFAULT_MAX_CACHEABLE_SIZE,
74        }
75    }
76
77    /// Replace the cache options.
78    pub fn with_options(mut self, options: CacheOptions) -> Self {
79        self.options = options;
80        self
81    }
82
83    /// Mark this cache as a *shared cache* (proxy/CDN). Equivalent to
84    /// `with_options` with `shared: true`.
85    pub fn shared(mut self) -> Self {
86        self.options.shared = true;
87        self
88    }
89
90    /// Set the cap on response body bytes the cache will store.
91    /// Responses larger than this pass through but are not stored. If
92    /// the cap is exceeded mid-stream, the cache write is aborted and
93    /// the remainder of the body passes through unmodified.
94    pub fn with_max_cacheable_size(mut self, max: u64) -> Self {
95        self.max_cacheable_size = max;
96        self
97    }
98
99    /// Borrow the storage backend.
100    pub fn storage(&self) -> &S {
101        &self.storage
102    }
103}
104
105// State stashed in the conn's typeset by `run` for `before_send` to pick up.
106enum CacheCtx<E: StoredEntry> {
107    /// Cache hit — `run` already populated a synthetic response and halted.
108    Hit,
109    /// Stored entry was stale and a conditional revalidation request has been spliced onto the
110    /// conn. `before_send` reconciles the downstream handler's reply (304 vs 200) with the stored
111    /// entry.
112    Revalidation { stored: E, key: CacheKey },
113    /// Cache miss — no stored entry matched. If the response is storable, `before_send` will
114    /// install a streaming tee into storage and pass the body through.
115    Miss { key: CacheKey },
116    /// Unsafe method (POST/PUT/DELETE/...). On a non-error response, `before_send` invalidates the
117    /// target URI per RFC 9111 §4.4.
118    Unsafe { url: Url },
119}
120
121impl<E: StoredEntry> std::fmt::Debug for CacheCtx<E> {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            Self::Hit => f.write_str("Hit"),
125            Self::Revalidation { key, .. } => f
126                .debug_struct("Revalidation")
127                .field("key", key)
128                .finish_non_exhaustive(),
129            Self::Miss { key } => f.debug_struct("Miss").field("key", key).finish(),
130            Self::Unsafe { url } => f.debug_struct("Unsafe").field("url", url).finish(),
131        }
132    }
133}
134
135// Build a `Url` from the request's effective scheme, host, and path-and-query. `is_secure()`
136// reflects `trillium-forwarding`'s view of TLS termination, which is the right scheme to key on
137// for a shared cache fronting trusted reverse proxies.
138fn url_from_conn(conn: &Conn) -> Option<Url> {
139    let scheme = if conn.is_secure() { "https" } else { "http" };
140    let host = conn.host()?;
141    let path_and_query = conn.path_and_query();
142    Url::parse(&format!("{scheme}://{host}{path_and_query}")).ok()
143}
144
145impl<S: CacheStorage> Handler for Cache<S> {
146    async fn run(&self, mut conn: Conn) -> Conn {
147        let method = conn.method();
148        let Some(url) = url_from_conn(&conn) else {
149            log::trace!("cache: no host on request, passing through without caching");
150            return conn;
151        };
152        let key = CacheKey::new(method, url.clone());
153        log::trace!("cache: run {method} {url}");
154
155        // RFC 9111 §4.4: don't read from cache for unsafe methods;
156        // possibly invalidate after the round-trip.
157        if !method.is_safe() {
158            log::trace!("cache: unsafe method {method}, bypassing cache read");
159            return conn.with_state(CacheCtx::<S::StoredEntry>::Unsafe { url });
160        }
161
162        let now = SystemTime::now();
163        let entries = self.storage.get(&key).await;
164        log::trace!("cache: {} stored candidate(s) for {key}", entries.len());
165
166        for entry in entries {
167            match entry.policy().before_request(conn.request_headers(), now) {
168                BeforeRequest::Fresh(cached) => {
169                    log::trace!("cache: hit (fresh) for {key}, serving cached response");
170                    *conn.response_headers_mut() = cached.headers;
171                    let body = match entry.open().await {
172                        Ok(b) => b,
173                        Err(e) => {
174                            log::warn!(
175                                "cache: open for hit failed for {key}: {e}, passing through"
176                            );
177                            return conn;
178                        }
179                    };
180                    return conn
181                        .with_state(CacheCtx::<S::StoredEntry>::Hit)
182                        .with_status(cached.status)
183                        .with_body(body)
184                        .halt();
185                }
186
187                BeforeRequest::NotModified(cached) => {
188                    // RFC 9111 §4.3.2 + RFC 9110 §13.2.2: client's conditional already matches
189                    // the cached entry. Send 304 with stripped headers and no body.
190                    log::trace!("cache: hit (fresh, conditional matches) for {key}, serving 304");
191                    *conn.response_headers_mut() = cached.headers;
192                    return conn
193                        .with_state(CacheCtx::<S::StoredEntry>::Hit)
194                        .with_status(cached.status)
195                        .with_body(Body::default())
196                        .halt();
197                }
198
199                BeforeRequest::Stale {
200                    request_headers,
201                    matches: true,
202                } => {
203                    // RFC 9111 §4.3: splice conditional-revalidation headers onto the request;
204                    // let the downstream handler run; reconcile in `before_send`.
205                    //
206                    // 0.1 caveat: no `stale-while-revalidate` — we always do synchronous
207                    // revalidation here, even for SWR-eligible entries.
208                    log::trace!("cache: stale for {key}, sending conditional revalidation request");
209                    *conn.request_headers_mut() = request_headers;
210                    return conn.with_state(CacheCtx::Revalidation { stored: entry, key });
211                }
212
213                BeforeRequest::Stale { matches: false, .. } => {
214                    log::trace!("cache: candidate vary-mismatch for {key}, trying next");
215                    continue;
216                }
217            }
218        }
219
220        log::trace!("cache: miss for {key}, forwarding to downstream handler");
221        conn.with_state(CacheCtx::<S::StoredEntry>::Miss { key })
222    }
223
224    async fn before_send(&self, mut conn: Conn) -> Conn {
225        let Some(ctx) = conn.take_state::<CacheCtx<S::StoredEntry>>() else {
226            return conn;
227        };
228
229        match ctx {
230            CacheCtx::Hit => conn,
231            CacheCtx::Revalidation { stored, key } => {
232                let now = SystemTime::now();
233                let origin_failed = conn.status().is_some_and(|s| s.is_server_error());
234                if origin_failed && stored.policy().is_sie_eligible(now) {
235                    log::trace!(
236                        "cache: stale-if-error recovery for {} (downstream {:?}), serving stale",
237                        conn.method(),
238                        conn.status()
239                    );
240                    return apply_stale(conn, stored, now).await;
241                }
242                if conn.status().is_none() {
243                    log::trace!("cache: downstream produced no status, passing through");
244                    return conn;
245                }
246                self.handle_revalidation(conn, stored, key).await
247            }
248            CacheCtx::Miss { key } => {
249                if conn.status().is_none() {
250                    log::trace!("cache: downstream produced no status, passing through");
251                    return conn;
252                }
253                self.handle_miss(conn, key).await
254            }
255            CacheCtx::Unsafe { url } => {
256                let Some(status) = conn.status() else {
257                    return conn;
258                };
259                if status.is_success() || status.is_redirection() {
260                    log::trace!(
261                        "cache: unsafe method {} → {}, invalidating GET and HEAD entries for {url}",
262                        conn.method(),
263                        status
264                    );
265                    self.invalidate_url(&url).await;
266
267                    // §4.4: also invalidate URIs in `Location` and `Content-Location` headers
268                    // when their host matches (DoS prevention).
269                    for header in [KnownHeaderName::Location, KnownHeaderName::ContentLocation] {
270                        let Some(value) = conn.response_headers().get_str(header) else {
271                            continue;
272                        };
273                        let Ok(target) = url.join(value) else {
274                            continue;
275                        };
276                        if target.host_str() != url.host_str() {
277                            continue;
278                        }
279                        log::trace!(
280                            "cache: unsafe method secondary invalidation via {header}: {target}"
281                        );
282                        self.invalidate_url(&target).await;
283                    }
284                }
285                conn
286            }
287        }
288    }
289}
290
291impl<S: CacheStorage> Cache<S> {
292    async fn invalidate_url(&self, url: &Url) {
293        self.storage
294            .invalidate(&CacheKey::new(Method::Get, url.clone()))
295            .await;
296        self.storage
297            .invalidate(&CacheKey::new(Method::Head, url.clone()))
298            .await;
299    }
300
301    async fn handle_revalidation(
302        &self,
303        mut conn: Conn,
304        mut stored: S::StoredEntry,
305        key: CacheKey,
306    ) -> Conn {
307        let now = SystemTime::now();
308        let status = conn.status().expect("checked above");
309        match stored.policy().after_response(
310            conn.request_headers(),
311            status,
312            conn.response_headers(),
313            now,
314        ) {
315            AfterResponse::NotModified(new_policy, cached_response) => {
316                log::trace!(
317                    "cache: revalidation 304 for {key}, reusing stored body and refreshing entry"
318                );
319                if let Err(e) = stored.refresh_policy(new_policy).await {
320                    log::warn!("cache: refresh_policy failed for {key}: {e}");
321                }
322                let body = match stored.open().await {
323                    Ok(b) => b,
324                    Err(e) => {
325                        log::warn!("cache: open after 304 failed for {key}: {e}, passing through");
326                        return conn;
327                    }
328                };
329                *conn.response_headers_mut() = cached_response.headers;
330                conn.set_status(cached_response.status);
331                conn.set_body(body);
332                conn
333            }
334            AfterResponse::Modified => {
335                // Drop the stored entry; treat as a fresh miss against the same key. The new
336                // entry replaces any stored variant with the same Vary signature.
337                drop(stored);
338                self.handle_miss(conn, key).await
339            }
340        }
341    }
342
343    async fn handle_miss(&self, mut conn: Conn, key: CacheKey) -> Conn {
344        let status = conn.status().expect("checked above");
345        if !CachePolicy::is_storable(
346            conn.method(),
347            conn.request_headers(),
348            status,
349            conn.response_headers(),
350            &self.options,
351        ) {
352            log::trace!("cache: miss for {key}, response not storable, passing through");
353            return conn;
354        }
355
356        // Skip the put entirely when content-length is known and already over cap.
357        if let Some(body_ref) = conn.response_body()
358            && let Some(len) = body_ref.len()
359            && len > self.max_cacheable_size
360        {
361            log::trace!(
362                "cache: miss for {key}, body {len} > max {}, not caching",
363                self.max_cacheable_size
364            );
365            return conn;
366        }
367
368        let policy = CachePolicy::new(
369            conn.method(),
370            conn.request_headers(),
371            status,
372            conn.response_headers().clone(),
373            SystemTime::now(),
374            self.options,
375        );
376        let put_handle = match self.storage.put(key.clone(), policy).await {
377            Ok(h) => h,
378            Err(e) => {
379                log::warn!("cache: put({key}) failed: {e}, passing through");
380                return conn;
381            }
382        };
383
384        let Some(body) = conn.take_response_body() else {
385            log::trace!("cache: miss for {key}, no body, passing through");
386            return conn;
387        };
388        let len = body.len();
389        log::trace!("cache: miss for {key}, streaming through tee");
390        let body = body.without_chunked_framing();
391        let tee = TeeingReader::new(body, put_handle, self.max_cacheable_size);
392        conn.set_body(Body::new_with_trailers(tee, len));
393        conn
394    }
395}
396
397// RFC 5861 stale-if-error recovery: replace conn's response state with the stored entry's.
398async fn apply_stale<E: StoredEntry>(mut conn: Conn, stored: E, now: SystemTime) -> Conn {
399    let cached = stored.policy().cached_response(now);
400    let body = match stored.open().await {
401        Ok(b) => b,
402        Err(e) => {
403            log::warn!("cache: open for stale serve failed: {e}, passing through");
404            return conn;
405        }
406    };
407    *conn.response_headers_mut() = cached.headers;
408    conn.set_status(cached.status);
409    conn.set_body(body);
410    conn
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemoryStorage;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use trillium_testing::{TestResult, TestServer, harness, test};

    /// Reads how many times an inner handler has run so far.
    fn invocations(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// Inner handler that counts its invocations and answers `body-N`, where `N` is the
    /// zero-based invocation index. With an etag configured it answers 304 to a matching
    /// `If-None-Match`.
    #[derive(Debug, Clone)]
    struct CountingHandler {
        counter: Arc<AtomicUsize>,
        cache_control: &'static str,
        etag: Option<&'static str>,
    }

    impl CountingHandler {
        fn new(cache_control: &'static str) -> Self {
            Self {
                counter: Arc::new(AtomicUsize::new(0)),
                cache_control,
                etag: None,
            }
        }

        fn with_etag(self, etag: &'static str) -> Self {
            Self {
                etag: Some(etag),
                ..self
            }
        }
    }

    impl Handler for CountingHandler {
        async fn run(&self, conn: Conn) -> Conn {
            let n = self.counter.fetch_add(1, Ordering::SeqCst);

            // Conditional request carrying our current etag: answer 304 with no body.
            let presented = conn.request_headers().get_str(KnownHeaderName::IfNoneMatch);
            let matched = self.etag.filter(|etag| presented == Some(*etag));
            if let Some(etag) = matched {
                return conn
                    .with_response_header(KnownHeaderName::Etag, etag)
                    .with_status(304)
                    .halt();
            }

            let mut conn = conn
                .with_response_header(KnownHeaderName::CacheControl, self.cache_control)
                .ok(format!("body-{n}"));
            if let Some(etag) = self.etag {
                conn.response_headers_mut()
                    .insert(KnownHeaderName::Etag, etag);
            }
            conn
        }
    }

    /// Default-configured in-memory cache mounted in front of `inner`.
    fn cache_app(inner: CountingHandler) -> impl Handler {
        (Cache::new(InMemoryStorage::new()), inner)
    }

    #[test(harness)]
    async fn first_request_misses_subsequent_request_hits() -> TestResult {
        let inner = CountingHandler::new("max-age=600");
        let counter = Arc::clone(&inner.counter);
        let app = TestServer::new(cache_app(inner)).await;

        // Miss: the inner handler produces the body.
        app.get("/x").await.assert_ok().assert_body("body-0");
        // Hit: same body again, served from storage.
        app.get("/x").await.assert_ok().assert_body("body-0");

        assert_eq!(invocations(&counter), 1, "inner handler only hit once");
        Ok(())
    }

    #[test(harness)]
    async fn different_urls_dont_collide() -> TestResult {
        let inner = CountingHandler::new("max-age=600");
        let counter = Arc::clone(&inner.counter);
        let app = TestServer::new(cache_app(inner)).await;

        app.get("/a").await.assert_body("body-0");
        app.get("/b").await.assert_body("body-1");

        assert_eq!(invocations(&counter), 2);
        Ok(())
    }

    #[test(harness)]
    async fn no_store_response_is_not_cached() -> TestResult {
        let inner = CountingHandler::new("no-store");
        let counter = Arc::clone(&inner.counter);
        let app = TestServer::new(cache_app(inner)).await;

        app.get("/x").await.assert_body("body-0");
        app.get("/x").await.assert_body("body-1");

        assert_eq!(invocations(&counter), 2);
        Ok(())
    }

    #[test(harness)]
    async fn post_invalidates_existing_entry() -> TestResult {
        let inner = CountingHandler::new("max-age=600");
        let counter = Arc::clone(&inner.counter);
        let app = TestServer::new(cache_app(inner)).await;

        app.get("/x").await.assert_body("body-0");
        // The POST reaches the inner handler (invocation 1) and evicts the stored GET entry.
        let _ = app.post("/x").await;
        app.get("/x").await.assert_body("body-2");

        assert_eq!(invocations(&counter), 3);
        Ok(())
    }

    // §4.3 + §3.2: stored stale → revalidation → 304 → reuse cached body.
    #[test(harness)]
    async fn stale_with_etag_revalidates_to_304() -> TestResult {
        let inner = CountingHandler::new("max-age=0").with_etag(r#""v1""#);
        let counter = Arc::clone(&inner.counter);
        let app = TestServer::new(cache_app(inner)).await;

        app.get("/x").await.assert_body("body-0");
        assert_eq!(invocations(&counter), 1);

        // The entry is stale, so the cache sends a conditional revalidation; the inner
        // handler replies 304 and the cache serves the stored body with its original status.
        app.get("/x").await.assert_ok().assert_body("body-0");
        assert_eq!(invocations(&counter), 2);
        Ok(())
    }

    #[test(harness)]
    async fn vary_isolates_entries_by_request_header() -> TestResult {
        /// Echoes the request's `Accept-Encoding` into the body and varies on it.
        #[derive(Debug, Clone, Default)]
        struct VaryHandler(Arc<AtomicUsize>);
        impl Handler for VaryHandler {
            async fn run(&self, conn: Conn) -> Conn {
                self.0.fetch_add(1, Ordering::SeqCst);
                let ae = conn
                    .request_headers()
                    .get_str(KnownHeaderName::AcceptEncoding)
                    .unwrap_or("none")
                    .to_string();
                conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
                    .with_response_header(KnownHeaderName::Vary, "Accept-Encoding")
                    .ok(format!("body-for-{ae}"))
            }
        }

        let inner = VaryHandler::default();
        let counter = Arc::clone(&inner.0);
        let app = TestServer::new((Cache::new(InMemoryStorage::new()), inner)).await;

        // gzip and br each miss once; the repeated gzip request is served from the cache.
        app.get("/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
            .await
            .assert_body("body-for-gzip");
        app.get("/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "br")
            .await
            .assert_body("body-for-br");
        app.get("/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
            .await
            .assert_body("body-for-gzip");

        assert_eq!(invocations(&counter), 2);
        Ok(())
    }

    #[test(harness)]
    async fn oversized_body_is_served_but_not_cached() -> TestResult {
        let inner = CountingHandler::new("max-age=600");
        let counter = Arc::clone(&inner.counter);
        // "body-N" is 6 bytes — cap at 3 so nothing is stored.
        let tiny_cache = Cache::new(InMemoryStorage::new()).with_max_cacheable_size(3);
        let app = TestServer::new((tiny_cache, inner)).await;

        app.get("/x").await.assert_body("body-0");
        app.get("/x").await.assert_body("body-1");

        assert_eq!(invocations(&counter), 2);
        Ok(())
    }

    // RFC 5861 stale-if-error: downstream returns 5xx, cache serves stored stale entry.
    #[test(harness)]
    async fn sie_serves_stale_on_5xx() -> TestResult {
        /// Succeeds once with a stale-if-error window, then fails with 500 on every
        /// later request.
        #[derive(Debug, Clone)]
        struct FlakyHandler(Arc<AtomicUsize>);
        impl Handler for FlakyHandler {
            async fn run(&self, conn: Conn) -> Conn {
                match self.0.fetch_add(1, Ordering::SeqCst) {
                    0 => conn
                        .with_response_header(
                            KnownHeaderName::CacheControl,
                            "max-age=0, stale-if-error=3600",
                        )
                        .ok("stable"),
                    _ => conn.with_status(500).halt(),
                }
            }
        }

        let inner = FlakyHandler(Arc::new(AtomicUsize::new(0)));
        let counter = Arc::clone(&inner.0);
        let app = TestServer::new((Cache::new(InMemoryStorage::new()), inner)).await;

        // First request populates the cache.
        app.get("/x").await.assert_ok().assert_body("stable");
        assert_eq!(invocations(&counter), 1);

        // Second request reaches the failing inner handler, yet the client still gets the
        // stored response.
        app.get("/x").await.assert_ok().assert_body("stable");
        assert_eq!(invocations(&counter), 2);
        Ok(())
    }
}