Skip to main content

rust_web_server/cache/
mod.rs

//! In-memory response cache middleware.
//!
//! [`CacheLayer`] implements [`Middleware`] and short-circuits the inner
//! application for cacheable `GET` responses within their TTL. Entries are
//! bounded by a configurable capacity; when the store is full and no expired
//! entries remain, the oldest entry is evicted.
//!
//! # What is cached
//!
//! - Method: **GET only**; all other methods bypass the cache.
//! - Status: 2xx responses (200, 201, 203, 204, 206, …).
//! - Responses carrying `Cache-Control: no-store` or `private` are **not** cached.
//! - Requests carrying `Cache-Control: no-cache` skip the cache lookup and call
//!   the handler, but the fresh response **is** stored (a forced refresh).
//!
//! # Example
//!
//! ```rust,no_run
//! use rust_web_server::app::App;
//! use rust_web_server::cache::CacheLayer;
//! use rust_web_server::core::New;
//!
//! let app = App::new()
//!     .wrap(CacheLayer::memory(1000).ttl(60).vary_by_header("Accept"));
//! ```
21//! use rust_web_server::core::New;
22//!
23//! let app = App::new()
24//!     .wrap(CacheLayer::memory(1000).ttl(60).vary_by_header("Accept"));
25//! ```
26
27#[cfg(test)]
28mod tests;
29
30use std::collections::{HashMap, VecDeque};
31use std::sync::{Mutex, OnceLock};
32use std::time::{Duration, Instant};
33
34use crate::application::Application;
35use crate::header::Header;
36use crate::middleware::Middleware;
37use crate::request::{METHOD, Request};
38use crate::response::Response;
39use crate::server::ConnectionInfo;
40
41// ── cache store ───────────────────────────────────────────────────────────────
42
43struct CachedEntry {
44    response: Response,
45    inserted_at: Instant,
46}
47
48struct CacheStore {
49    entries: HashMap<String, CachedEntry>,
50    /// Insertion order — front is oldest; used for capacity eviction.
51    order: VecDeque<String>,
52}
53
54impl CacheStore {
55    fn new() -> Self {
56        CacheStore { entries: HashMap::new(), order: VecDeque::new() }
57    }
58
59    fn get(&self, key: &str, ttl: Duration) -> Option<&CachedEntry> {
60        self.entries.get(key).filter(|e| e.inserted_at.elapsed() < ttl)
61    }
62
63    fn insert(&mut self, key: String, entry: CachedEntry, capacity: usize) {
64        // Update in place without disturbing insertion order.
65        if self.entries.contains_key(&key) {
66            self.entries.insert(key, entry);
67            return;
68        }
69        // `purge_expired` is called by the caller before `insert`, so any
70        // remaining entries are still live. Evict the oldest if we're full.
71        if self.entries.len() >= capacity {
72            if let Some(oldest) = self.order.pop_front() {
73                self.entries.remove(&oldest);
74            }
75        }
76        self.order.push_back(key.clone());
77        self.entries.insert(key, entry);
78    }
79
80    fn purge_expired(&mut self, ttl: Duration) {
81        let expired: Vec<String> = self.entries.iter()
82            .filter(|(_, e)| e.inserted_at.elapsed() >= ttl)
83            .map(|(k, _)| k.clone())
84            .collect();
85        for k in &expired {
86            self.entries.remove(k);
87            self.order.retain(|o| o != k);
88        }
89    }
90}
91
92// ── CacheLayer ────────────────────────────────────────────────────────────────
93
94/// An in-memory response cache middleware.
95///
96/// Construct with [`CacheLayer::memory`] and configure with the builder methods
97/// [`ttl`](CacheLayer::ttl) and [`vary_by_header`](CacheLayer::vary_by_header).
98pub struct CacheLayer {
99    store: OnceLock<Mutex<CacheStore>>,
100    capacity: usize,
101    ttl: Duration,
102    vary_headers: Vec<String>,
103}
104
105impl CacheLayer {
106    /// Create a new in-memory cache bounded to `capacity` entries.
107    ///
108    /// Default TTL is **60 seconds**. Adjust with [`.ttl()`](CacheLayer::ttl).
109    pub fn memory(capacity: usize) -> Self {
110        CacheLayer {
111            store: OnceLock::new(),
112            capacity,
113            ttl: Duration::from_secs(60),
114            vary_headers: vec![],
115        }
116    }
117
118    /// Set the time-to-live for cached entries.
119    pub fn ttl(mut self, secs: u64) -> Self {
120        self.ttl = Duration::from_secs(secs);
121        self
122    }
123
124    /// Include a request header in the cache key so that different values of
125    /// that header produce separate cache entries.
126    ///
127    /// Header name matching is case-insensitive. Call multiple times to vary
128    /// by more than one header.
129    ///
130    /// ```rust,no_run
131    /// use rust_web_server::cache::CacheLayer;
132    ///
133    /// let layer = CacheLayer::memory(500)
134    ///     .vary_by_header("Accept")
135    ///     .vary_by_header("Accept-Language");
136    /// ```
137    pub fn vary_by_header(mut self, name: &str) -> Self {
138        self.vary_headers.push(name.to_ascii_lowercase());
139        self
140    }
141
142    fn store(&self) -> &Mutex<CacheStore> {
143        self.store.get_or_init(|| Mutex::new(CacheStore::new()))
144    }
145
146    /// Build a cache key from the request URI and any configured vary headers.
147    fn cache_key(&self, request: &Request) -> String {
148        let mut key = request.request_uri.clone();
149        for vh in &self.vary_headers {
150            let val = request.headers.iter()
151                .find(|h| h.name.eq_ignore_ascii_case(vh))
152                .map(|h| h.value.as_str())
153                .unwrap_or("");
154            key.push('\x00');
155            key.push_str(val);
156        }
157        key
158    }
159
160    /// `true` when the request carries `Cache-Control: no-cache`, meaning the
161    /// client wants a fresh response (but we may still store the result).
162    fn request_bypasses_cache(request: &Request) -> bool {
163        request.headers.iter().any(|h| {
164            h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL)
165                && h.value.to_ascii_lowercase().contains("no-cache")
166        })
167    }
168
169    /// `true` when the response may be stored in the cache.
170    fn response_is_cacheable(response: &Response) -> bool {
171        if response.status_code < 200 || response.status_code >= 300 {
172            return false;
173        }
174        !response.headers.iter().any(|h| {
175            if !h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL) {
176                return false;
177            }
178            let v = h.value.to_ascii_lowercase();
179            v.contains("no-store") || v.contains("private")
180        })
181    }
182
183    /// Age of the entry in whole seconds, capped at u64::MAX.
184    fn age_secs(entry: &CachedEntry) -> u64 {
185        entry.inserted_at.elapsed().as_secs()
186    }
187
188    /// Build a response from a cache hit, injecting an `Age` header.
189    fn cached_response(entry: &CachedEntry) -> Response {
190        let mut response = entry.response.clone();
191        let age = Self::age_secs(entry);
192        // Replace or add the Age header.
193        if let Some(h) = response.headers.iter_mut().find(|h| h.name.eq_ignore_ascii_case("Age")) {
194            h.value = age.to_string();
195        } else {
196            response.headers.push(Header { name: "Age".to_string(), value: age.to_string() });
197        }
198        response
199    }
200}
201
202impl Middleware for CacheLayer {
203    fn handle(
204        &self,
205        request: &Request,
206        connection: &ConnectionInfo,
207        next: &dyn Application,
208    ) -> Result<Response, String> {
209        // Only cache GET requests.
210        if request.method != METHOD.get {
211            return next.execute(request, connection);
212        }
213
214        let key = self.cache_key(request);
215        let bypass = Self::request_bypasses_cache(request);
216
217        if !bypass {
218            // Check for a valid cache hit.
219            let guard = self.store().lock().unwrap();
220            if let Some(entry) = guard.get(&key, self.ttl) {
221                return Ok(Self::cached_response(entry));
222            }
223        }
224
225        // Cache miss (or bypass): call the handler.
226        let response = next.execute(request, connection)?;
227
228        if Self::response_is_cacheable(&response) {
229            let mut guard = self.store().lock().unwrap();
230            guard.purge_expired(self.ttl);
231            guard.insert(
232                key,
233                CachedEntry { response: response.clone(), inserted_at: Instant::now() },
234                self.capacity,
235            );
236        }
237
238        Ok(response)
239    }
240}