Skip to main content

rust_web_server/cache/
mod.rs

1//! In-memory response cache middleware.
2//!
3//! [`CacheLayer`] implements [`Middleware`] and short-circuits the inner
4//! application for cacheable `GET` responses within their TTL.  Entries are
5//! bounded by a configurable capacity; the oldest entry is evicted when the
6//! store is full and no expired entries remain.
7//!
8//! # What is cached
9//!
10//! - Method: **GET only**; all other methods bypass the cache.
11//! - Status: 2xx responses (200, 201, 203, 204, 206, …).
12//! - Response `Cache-Control: no-store` or `private` — **not** cached.
13//! - Request `Cache-Control: no-cache` — cache is bypassed, handler is called,
14//!   but the fresh response **is** stored (revalidation).
15//!
16//! # Example
17//!
18//! ```rust,no_run
19//! use rust_web_server::app::App;
20//! use rust_web_server::cache::CacheLayer;
21//! use rust_web_server::core::New;
22//!
23//! let app = App::new()
24//!     .wrap(CacheLayer::memory(1000).ttl(60).vary_by_header("Accept"));
25//! ```
26
27#[cfg(test)]
28mod tests;
29
30use std::collections::{HashMap, VecDeque};
31use std::sync::atomic::{AtomicU64, Ordering};
32use std::sync::{Arc, Mutex};
33use std::time::{Duration, Instant};
34
35use crate::application::Application;
36use crate::header::Header;
37use crate::middleware::Middleware;
38use crate::request::{METHOD, Request};
39use crate::response::Response;
40use crate::server::ConnectionInfo;
41
42// ── cache store ───────────────────────────────────────────────────────────────
43
44struct CachedEntry {
45    response: Response,
46    inserted_at: Instant,
47}
48
49struct CacheStore {
50    entries: HashMap<String, CachedEntry>,
51    /// Insertion order — front is oldest; used for capacity eviction.
52    order: VecDeque<String>,
53}
54
55impl CacheStore {
56    fn new() -> Self {
57        CacheStore { entries: HashMap::new(), order: VecDeque::new() }
58    }
59
60    fn get(&self, key: &str, ttl: Duration) -> Option<&CachedEntry> {
61        self.entries.get(key).filter(|e| e.inserted_at.elapsed() < ttl)
62    }
63
64    fn insert(&mut self, key: String, entry: CachedEntry, capacity: usize) {
65        // Update in place without disturbing insertion order.
66        if self.entries.contains_key(&key) {
67            self.entries.insert(key, entry);
68            return;
69        }
70        // `purge_expired` is called by the caller before `insert`, so any
71        // remaining entries are still live. Evict the oldest if we're full.
72        if self.entries.len() >= capacity {
73            if let Some(oldest) = self.order.pop_front() {
74                self.entries.remove(&oldest);
75            }
76        }
77        self.order.push_back(key.clone());
78        self.entries.insert(key, entry);
79    }
80
81    fn purge_expired(&mut self, ttl: Duration) {
82        let expired: Vec<String> = self.entries.iter()
83            .filter(|(_, e)| e.inserted_at.elapsed() >= ttl)
84            .map(|(k, _)| k.clone())
85            .collect();
86        for k in &expired {
87            self.entries.remove(k);
88            self.order.retain(|o| o != k);
89        }
90    }
91}
92
93// ── CacheLayer ────────────────────────────────────────────────────────────────
94
95/// An in-memory response cache middleware.
96///
97/// Construct with [`CacheLayer::memory`] and configure with the builder methods
98/// [`ttl`](CacheLayer::ttl) and [`vary_by_header`](CacheLayer::vary_by_header).
99///
100/// `CacheLayer` is cheaply `Clone`-able — clones share the same underlying
101/// store, which lets you keep a handle for cache stats and invalidation while
102/// the other clone is used as middleware.
103#[derive(Clone)]
104pub struct CacheLayer {
105    store: Arc<Mutex<CacheStore>>,
106    hits: Arc<AtomicU64>,
107    misses: Arc<AtomicU64>,
108    capacity: usize,
109    ttl: Duration,
110    vary_headers: Vec<String>,
111}
112
113impl CacheLayer {
114    /// Create a new in-memory cache bounded to `capacity` entries.
115    ///
116    /// Default TTL is **60 seconds**. Adjust with [`.ttl()`](CacheLayer::ttl).
117    pub fn memory(capacity: usize) -> Self {
118        CacheLayer {
119            store: Arc::new(Mutex::new(CacheStore::new())),
120            hits: Arc::new(AtomicU64::new(0)),
121            misses: Arc::new(AtomicU64::new(0)),
122            capacity,
123            ttl: Duration::from_secs(60),
124            vary_headers: vec![],
125        }
126    }
127
128    /// Number of cache hits since the cache was created.
129    pub fn hits(&self) -> u64 {
130        self.hits.load(Ordering::Relaxed)
131    }
132
133    /// Number of cache misses since the cache was created.
134    pub fn misses(&self) -> u64 {
135        self.misses.load(Ordering::Relaxed)
136    }
137
138    /// Current number of entries in the store (including potentially expired ones
139    /// that haven't been purged yet).
140    pub fn size(&self) -> usize {
141        self.store.lock().unwrap().entries.len()
142    }
143
144    /// Remove all entries from the cache.
145    pub fn clear(&self) {
146        let mut guard = self.store.lock().unwrap();
147        guard.entries.clear();
148        guard.order.clear();
149    }
150
151    /// Set the time-to-live for cached entries.
152    pub fn ttl(mut self, secs: u64) -> Self {
153        self.ttl = Duration::from_secs(secs);
154        self
155    }
156
157    /// Include a request header in the cache key so that different values of
158    /// that header produce separate cache entries.
159    ///
160    /// Header name matching is case-insensitive. Call multiple times to vary
161    /// by more than one header.
162    ///
163    /// ```rust,no_run
164    /// use rust_web_server::cache::CacheLayer;
165    ///
166    /// let layer = CacheLayer::memory(500)
167    ///     .vary_by_header("Accept")
168    ///     .vary_by_header("Accept-Language");
169    /// ```
170    pub fn vary_by_header(mut self, name: &str) -> Self {
171        self.vary_headers.push(name.to_ascii_lowercase());
172        self
173    }
174
175    fn store(&self) -> &Mutex<CacheStore> {
176        &self.store
177    }
178
179    /// Build a cache key from the request URI and any configured vary headers.
180    fn cache_key(&self, request: &Request) -> String {
181        let mut key = request.request_uri.clone();
182        for vh in &self.vary_headers {
183            let val = request.headers.iter()
184                .find(|h| h.name.eq_ignore_ascii_case(vh))
185                .map(|h| h.value.as_str())
186                .unwrap_or("");
187            key.push('\x00');
188            key.push_str(val);
189        }
190        key
191    }
192
193    /// `true` when the request carries `Cache-Control: no-cache`, meaning the
194    /// client wants a fresh response (but we may still store the result).
195    fn request_bypasses_cache(request: &Request) -> bool {
196        request.headers.iter().any(|h| {
197            h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL)
198                && h.value.to_ascii_lowercase().contains("no-cache")
199        })
200    }
201
202    /// `true` when the response may be stored in the cache.
203    fn response_is_cacheable(response: &Response) -> bool {
204        if response.status_code < 200 || response.status_code >= 300 {
205            return false;
206        }
207        !response.headers.iter().any(|h| {
208            if !h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL) {
209                return false;
210            }
211            let v = h.value.to_ascii_lowercase();
212            v.contains("no-store") || v.contains("private")
213        })
214    }
215
216    /// Age of the entry in whole seconds, capped at u64::MAX.
217    fn age_secs(entry: &CachedEntry) -> u64 {
218        entry.inserted_at.elapsed().as_secs()
219    }
220
221    /// Build a response from a cache hit, injecting an `Age` header.
222    fn cached_response(entry: &CachedEntry) -> Response {
223        let mut response = entry.response.clone();
224        let age = Self::age_secs(entry);
225        // Replace or add the Age header.
226        if let Some(h) = response.headers.iter_mut().find(|h| h.name.eq_ignore_ascii_case("Age")) {
227            h.value = age.to_string();
228        } else {
229            response.headers.push(Header { name: "Age".to_string(), value: age.to_string() });
230        }
231        response
232    }
233}
234
235impl Middleware for CacheLayer {
236    fn handle(
237        &self,
238        request: &Request,
239        connection: &ConnectionInfo,
240        next: &dyn Application,
241    ) -> Result<Response, String> {
242        // Only cache GET requests.
243        if request.method != METHOD.get {
244            return next.execute(request, connection);
245        }
246
247        let key = self.cache_key(request);
248        let bypass = Self::request_bypasses_cache(request);
249
250        if !bypass {
251            // Check for a valid cache hit.
252            let guard = self.store().lock().unwrap();
253            if let Some(entry) = guard.get(&key, self.ttl) {
254                self.hits.fetch_add(1, Ordering::Relaxed);
255                return Ok(Self::cached_response(entry));
256            }
257        }
258        self.misses.fetch_add(1, Ordering::Relaxed);
259
260        // Cache miss (or bypass): call the handler.
261        let response = next.execute(request, connection)?;
262
263        if Self::response_is_cacheable(&response) {
264            let mut guard = self.store().lock().unwrap();
265            guard.purge_expired(self.ttl);
266            guard.insert(
267                key,
268                CachedEntry { response: response.clone(), inserted_at: Instant::now() },
269                self.capacity,
270            );
271        }
272
273        Ok(response)
274    }
275}