1mod config;
42mod error;
43mod events;
44mod layer;
45mod store;
46
47pub use config::{CacheConfig, CacheConfigBuilder, KeyExtractor};
48pub use error::CacheError;
49pub use events::CacheEvent;
50pub use layer::CacheLayer;
51
52use futures::future::BoxFuture;
53use std::hash::Hash;
54use std::sync::{Arc, Mutex};
55use std::task::{Context, Poll};
56use std::time::Instant;
57use store::CacheStore;
58use tower::Service;
59
/// Tower middleware that caches successful responses of an inner service.
///
/// Each request is mapped to a key by the configured key extractor; responses are
/// stored under that key in a store bounded by the configured `max_size` and `ttl`.
/// Cloning a `Cache` clones the inner service but shares the configuration and the
/// backing store between all clones.
pub struct Cache<S, Req, K, Resp> {
    /// The wrapped service; it is only called on cache misses.
    inner: S,
    /// Shared configuration (key extractor, capacity, TTL, name, event listeners).
    config: Arc<CacheConfig<Req, K>>,
    /// Backing store, shared by every clone of this service.
    store: Arc<Mutex<CacheStore<K, Resp>>>,
}
73
74impl<S, Req, K, Resp> Cache<S, Req, K, Resp>
75where
76 K: Hash + Eq,
77 Resp: Clone,
78{
79 pub fn new(inner: S, config: Arc<CacheConfig<Req, K>>) -> Self {
81 let store = Arc::new(Mutex::new(CacheStore::new(config.max_size, config.ttl)));
82 Self {
83 inner,
84 config,
85 store,
86 }
87 }
88}
89
90impl<S, Req, K, Resp> Clone for Cache<S, Req, K, Resp>
91where
92 S: Clone,
93{
94 fn clone(&self) -> Self {
95 Self {
96 inner: self.inner.clone(),
97 config: Arc::clone(&self.config),
98 store: Arc::clone(&self.store),
99 }
100 }
101}
102
103impl<S, Req, K> Service<Req> for Cache<S, Req, K, S::Response>
104where
105 S: Service<Req>,
106 S::Response: Clone + Send + 'static,
107 K: Hash + Eq + Clone + Send + 'static,
108 Req: Send + 'static,
109 S::Future: Send + 'static,
110{
111 type Response = S::Response;
112 type Error = CacheError<S::Error>;
113 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
114
115 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116 self.inner.poll_ready(cx).map_err(CacheError::Inner)
117 }
118
119 fn call(&mut self, req: Req) -> Self::Future {
120 let key = (self.config.key_extractor)(&req);
121
122 let cached = {
124 let mut store = self.store.lock().unwrap();
125 store.get(&key)
126 };
127
128 if let Some(response) = cached {
129 let event = CacheEvent::Hit {
131 pattern_name: self.config.name.clone(),
132 timestamp: Instant::now(),
133 };
134 self.config.event_listeners.emit(&event);
135 return Box::pin(async move { Ok(response) });
136 }
137
138 let miss_event = CacheEvent::Miss {
140 pattern_name: self.config.name.clone(),
141 timestamp: Instant::now(),
142 };
143 self.config.event_listeners.emit(&miss_event);
144
145 let future = self.inner.call(req);
146 let store = Arc::clone(&self.store);
147 let config = Arc::clone(&self.config);
148
149 Box::pin(async move {
150 let response = future.await.map_err(CacheError::Inner)?;
151
152 let was_evicted = {
154 let mut store = store.lock().unwrap();
155 let was_full = store.len() >= config.max_size;
156 store.insert(key, response.clone());
157 was_full
158 };
159
160 if was_evicted {
161 let event = CacheEvent::Eviction {
162 pattern_name: config.name.clone(),
163 timestamp: Instant::now(),
164 };
165 config.event_listeners.emit(&event);
166 }
167
168 Ok(response)
169 })
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::Layer;
    use tower::ServiceExt;

    /// Waits for `service` to become ready, then sends it `req` as an owned `String`.
    async fn request<S>(service: &mut S, req: &str) -> Result<S::Response, S::Error>
    where
        S: Service<String>,
    {
        service.ready().await?.call(req.to_string()).await
    }

    #[tokio::test]
    async fn cache_hit_returns_cached_response() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response1 = request(&mut service, "test").await.unwrap();
        assert_eq!(response1, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // The second identical request must be answered from the cache.
        let response2 = request(&mut service, "test").await.unwrap();
        assert_eq!(response2, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cache_miss_calls_inner_service() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response = request(&mut service, "test").await.unwrap();
        assert_eq!(response, "Response: test");
    }

    #[tokio::test]
    async fn different_keys_not_cached_together() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        request(&mut service, "test1").await.unwrap();
        request(&mut service, "test2").await.unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn ttl_expiration_causes_cache_miss() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .ttl(Duration::from_millis(50))
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        request(&mut service, "test").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Sleep past the 50ms TTL so the entry is stale.
        tokio::time::sleep(Duration::from_millis(100)).await;

        request(&mut service, "test").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn lru_eviction_removes_least_recently_used() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(2)
            .key_extractor(|req: &String| req.clone())
            .build();

        // Eviction is observed through the same service instance: a second
        // `layer.layer(..)` may get its own store and would miss regardless.
        let mut service = layer.layer(service);

        request(&mut service, "key1").await.unwrap();
        request(&mut service, "key2").await.unwrap();
        // Capacity is 2, so storing key3 must evict key1, the least recently used.
        request(&mut service, "key3").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // The newest entry is still cached...
        request(&mut service, "key3").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // ...but key1 was evicted and must be fetched from the inner service again.
        request(&mut service, "key1").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let hit_count = Arc::new(AtomicUsize::new(0));
        let miss_count = Arc::new(AtomicUsize::new(0));
        let eviction_count = Arc::new(AtomicUsize::new(0));

        let hc = Arc::clone(&hit_count);
        let mc = Arc::clone(&miss_count);
        let ec = Arc::clone(&eviction_count);

        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(1)
            .key_extractor(|req: &String| req.clone())
            .on_hit(move || {
                hc.fetch_add(1, Ordering::SeqCst);
            })
            .on_miss(move || {
                mc.fetch_add(1, Ordering::SeqCst);
            })
            .on_eviction(move || {
                ec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        request(&mut service, "test").await.unwrap();
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);
        assert_eq!(hit_count.load(Ordering::SeqCst), 0);

        request(&mut service, "test").await.unwrap();
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);

        // A different key in a cache of size 1 displaces the existing entry.
        request(&mut service, "other").await.unwrap();
        assert_eq!(eviction_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_not_cached() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(std::io::Error::other("error"))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let _ = request(&mut service, "test").await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // The failed response was not stored, so the inner service runs again.
        let _ = request(&mut service, "test").await;
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}