rust_web_server/cache/
mod.rs1#[cfg(test)]
28mod tests;
29
30use std::collections::{HashMap, VecDeque};
31use std::sync::atomic::{AtomicU64, Ordering};
32use std::sync::{Arc, Mutex};
33use std::time::{Duration, Instant};
34
35use crate::application::Application;
36use crate::header::Header;
37use crate::middleware::Middleware;
38use crate::request::{METHOD, Request};
39use crate::response::Response;
40use crate::server::ConnectionInfo;
41
42struct CachedEntry {
45 response: Response,
46 inserted_at: Instant,
47}
48
49struct CacheStore {
50 entries: HashMap<String, CachedEntry>,
51 order: VecDeque<String>,
53}
54
55impl CacheStore {
56 fn new() -> Self {
57 CacheStore { entries: HashMap::new(), order: VecDeque::new() }
58 }
59
60 fn get(&self, key: &str, ttl: Duration) -> Option<&CachedEntry> {
61 self.entries.get(key).filter(|e| e.inserted_at.elapsed() < ttl)
62 }
63
64 fn insert(&mut self, key: String, entry: CachedEntry, capacity: usize) {
65 if self.entries.contains_key(&key) {
67 self.entries.insert(key, entry);
68 return;
69 }
70 if self.entries.len() >= capacity {
73 if let Some(oldest) = self.order.pop_front() {
74 self.entries.remove(&oldest);
75 }
76 }
77 self.order.push_back(key.clone());
78 self.entries.insert(key, entry);
79 }
80
81 fn purge_expired(&mut self, ttl: Duration) {
82 let expired: Vec<String> = self.entries.iter()
83 .filter(|(_, e)| e.inserted_at.elapsed() >= ttl)
84 .map(|(k, _)| k.clone())
85 .collect();
86 for k in &expired {
87 self.entries.remove(k);
88 self.order.retain(|o| o != k);
89 }
90 }
91}
92
93#[derive(Clone)]
104pub struct CacheLayer {
105 store: Arc<Mutex<CacheStore>>,
106 hits: Arc<AtomicU64>,
107 misses: Arc<AtomicU64>,
108 capacity: usize,
109 ttl: Duration,
110 vary_headers: Vec<String>,
111}
112
113impl CacheLayer {
114 pub fn memory(capacity: usize) -> Self {
118 CacheLayer {
119 store: Arc::new(Mutex::new(CacheStore::new())),
120 hits: Arc::new(AtomicU64::new(0)),
121 misses: Arc::new(AtomicU64::new(0)),
122 capacity,
123 ttl: Duration::from_secs(60),
124 vary_headers: vec![],
125 }
126 }
127
128 pub fn hits(&self) -> u64 {
130 self.hits.load(Ordering::Relaxed)
131 }
132
133 pub fn misses(&self) -> u64 {
135 self.misses.load(Ordering::Relaxed)
136 }
137
138 pub fn size(&self) -> usize {
141 self.store.lock().unwrap().entries.len()
142 }
143
144 pub fn clear(&self) {
146 let mut guard = self.store.lock().unwrap();
147 guard.entries.clear();
148 guard.order.clear();
149 }
150
151 pub fn ttl(mut self, secs: u64) -> Self {
153 self.ttl = Duration::from_secs(secs);
154 self
155 }
156
157 pub fn vary_by_header(mut self, name: &str) -> Self {
171 self.vary_headers.push(name.to_ascii_lowercase());
172 self
173 }
174
175 fn store(&self) -> &Mutex<CacheStore> {
176 &self.store
177 }
178
179 fn cache_key(&self, request: &Request) -> String {
181 let mut key = request.request_uri.clone();
182 for vh in &self.vary_headers {
183 let val = request.headers.iter()
184 .find(|h| h.name.eq_ignore_ascii_case(vh))
185 .map(|h| h.value.as_str())
186 .unwrap_or("");
187 key.push('\x00');
188 key.push_str(val);
189 }
190 key
191 }
192
193 fn request_bypasses_cache(request: &Request) -> bool {
196 request.headers.iter().any(|h| {
197 h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL)
198 && h.value.to_ascii_lowercase().contains("no-cache")
199 })
200 }
201
202 fn response_is_cacheable(response: &Response) -> bool {
204 if response.status_code < 200 || response.status_code >= 300 {
205 return false;
206 }
207 !response.headers.iter().any(|h| {
208 if !h.name.eq_ignore_ascii_case(Header::_CACHE_CONTROL) {
209 return false;
210 }
211 let v = h.value.to_ascii_lowercase();
212 v.contains("no-store") || v.contains("private")
213 })
214 }
215
216 fn age_secs(entry: &CachedEntry) -> u64 {
218 entry.inserted_at.elapsed().as_secs()
219 }
220
221 fn cached_response(entry: &CachedEntry) -> Response {
223 let mut response = entry.response.clone();
224 let age = Self::age_secs(entry);
225 if let Some(h) = response.headers.iter_mut().find(|h| h.name.eq_ignore_ascii_case("Age")) {
227 h.value = age.to_string();
228 } else {
229 response.headers.push(Header { name: "Age".to_string(), value: age.to_string() });
230 }
231 response
232 }
233}
234
235impl Middleware for CacheLayer {
236 fn handle(
237 &self,
238 request: &Request,
239 connection: &ConnectionInfo,
240 next: &dyn Application,
241 ) -> Result<Response, String> {
242 if request.method != METHOD.get {
244 return next.execute(request, connection);
245 }
246
247 let key = self.cache_key(request);
248 let bypass = Self::request_bypasses_cache(request);
249
250 if !bypass {
251 let guard = self.store().lock().unwrap();
253 if let Some(entry) = guard.get(&key, self.ttl) {
254 self.hits.fetch_add(1, Ordering::Relaxed);
255 return Ok(Self::cached_response(entry));
256 }
257 }
258 self.misses.fetch_add(1, Ordering::Relaxed);
259
260 let response = next.execute(request, connection)?;
262
263 if Self::response_is_cacheable(&response) {
264 let mut guard = self.store().lock().unwrap();
265 guard.purge_expired(self.ttl);
266 guard.insert(
267 key,
268 CachedEntry { response: response.clone(), inserted_at: Instant::now() },
269 self.capacity,
270 );
271 }
272
273 Ok(response)
274 }
275}