rust_web_server/cache/
mod.rs1#[cfg(test)]
28mod tests;
29
30use std::collections::{HashMap, VecDeque};
31use std::sync::{Mutex, OnceLock};
32use std::time::{Duration, Instant};
33
34use crate::application::Application;
35use crate::header::Header;
36use crate::middleware::Middleware;
37use crate::request::{METHOD, Request};
38use crate::response::Response;
39use crate::server::ConnectionInfo;
40
41struct CachedEntry {
44 response: Response,
45 inserted_at: Instant,
46}
47
48struct CacheStore {
49 entries: HashMap<String, CachedEntry>,
50 order: VecDeque<String>,
52}
53
54impl CacheStore {
55 fn new() -> Self {
56 CacheStore { entries: HashMap::new(), order: VecDeque::new() }
57 }
58
59 fn get(&self, key: &str, ttl: Duration) -> Option<&CachedEntry> {
60 self.entries.get(key).filter(|e| e.inserted_at.elapsed() < ttl)
61 }
62
63 fn insert(&mut self, key: String, entry: CachedEntry, capacity: usize) {
64 if self.entries.contains_key(&key) {
66 self.entries.insert(key, entry);
67 return;
68 }
69 if self.entries.len() >= capacity {
72 if let Some(oldest) = self.order.pop_front() {
73 self.entries.remove(&oldest);
74 }
75 }
76 self.order.push_back(key.clone());
77 self.entries.insert(key, entry);
78 }
79
80 fn purge_expired(&mut self, ttl: Duration) {
81 let expired: Vec<String> = self.entries.iter()
82 .filter(|(_, e)| e.inserted_at.elapsed() >= ttl)
83 .map(|(k, _)| k.clone())
84 .collect();
85 for k in &expired {
86 self.entries.remove(k);
87 self.order.retain(|o| o != k);
88 }
89 }
90}
91
92pub struct CacheLayer {
99 store: OnceLock<Mutex<CacheStore>>,
100 capacity: usize,
101 ttl: Duration,
102 vary_headers: Vec<String>,
103}
104
105impl CacheLayer {
106 pub fn memory(capacity: usize) -> Self {
110 CacheLayer {
111 store: OnceLock::new(),
112 capacity,
113 ttl: Duration::from_secs(60),
114 vary_headers: vec![],
115 }
116 }
117
118 pub fn ttl(mut self, secs: u64) -> Self {
120 self.ttl = Duration::from_secs(secs);
121 self
122 }
123
124 pub fn vary_by_header(mut self, name: &str) -> Self {
138 self.vary_headers.push(name.to_ascii_lowercase());
139 self
140 }
141
142 fn store(&self) -> &Mutex<CacheStore> {
143 self.store.get_or_init(|| Mutex::new(CacheStore::new()))
144 }
145
146 fn cache_key(&self, request: &Request) -> String {
148 let mut key = request.request_uri.clone();
149 for vh in &self.vary_headers {
150 let val = request.headers.iter()
151 .find(|h| h.name.eq_ignore_ascii_case(vh))
152 .map(|h| h.value.as_str())
153 .unwrap_or("");
154 key.push('\x00');
155 key.push_str(val);
156 }
157 key
158 }
159
160 fn request_bypasses_cache(request: &Request) -> bool {
163 request.headers.iter().any(|h| {
164 h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL)
165 && h.value.to_ascii_lowercase().contains("no-cache")
166 })
167 }
168
169 fn response_is_cacheable(response: &Response) -> bool {
171 if response.status_code < 200 || response.status_code >= 300 {
172 return false;
173 }
174 !response.headers.iter().any(|h| {
175 if !h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL) {
176 return false;
177 }
178 let v = h.value.to_ascii_lowercase();
179 v.contains("no-store") || v.contains("private")
180 })
181 }
182
183 fn age_secs(entry: &CachedEntry) -> u64 {
185 entry.inserted_at.elapsed().as_secs()
186 }
187
188 fn cached_response(entry: &CachedEntry) -> Response {
190 let mut response = entry.response.clone();
191 let age = Self::age_secs(entry);
192 if let Some(h) = response.headers.iter_mut().find(|h| h.name.eq_ignore_ascii_case("Age")) {
194 h.value = age.to_string();
195 } else {
196 response.headers.push(Header { name: "Age".to_string(), value: age.to_string() });
197 }
198 response
199 }
200}
201
202impl Middleware for CacheLayer {
203 fn handle(
204 &self,
205 request: &Request,
206 connection: &ConnectionInfo,
207 next: &dyn Application,
208 ) -> Result<Response, String> {
209 if request.method != METHOD.get {
211 return next.execute(request, connection);
212 }
213
214 let key = self.cache_key(request);
215 let bypass = Self::request_bypasses_cache(request);
216
217 if !bypass {
218 let guard = self.store().lock().unwrap();
220 if let Some(entry) = guard.get(&key, self.ttl) {
221 return Ok(Self::cached_response(entry));
222 }
223 }
224
225 let response = next.execute(request, connection)?;
227
228 if Self::response_is_cacheable(&response) {
229 let mut guard = self.store().lock().unwrap();
230 guard.purge_expired(self.ttl);
231 guard.insert(
232 key,
233 CachedEntry { response: response.clone(), inserted_at: Instant::now() },
234 self.capacity,
235 );
236 }
237
238 Ok(response)
239 }
240}