1mod config;
42mod error;
43mod events;
44mod layer;
45mod store;
46
47pub use config::{CacheConfig, CacheConfigBuilder, KeyExtractor};
48pub use error::CacheError;
49pub use events::CacheEvent;
50pub use layer::CacheLayer;
51
52use futures::future::BoxFuture;
53use std::hash::Hash;
54use std::sync::{Arc, Mutex};
55use std::task::{Context, Poll};
56use std::time::Instant;
57use store::CacheStore;
58use tower::Service;
59
60#[cfg(feature = "metrics")]
61use metrics::{counter, describe_counter, describe_gauge, gauge};
62
63#[cfg(feature = "tracing")]
64use tracing::{debug, info};
65
/// Tower middleware that caches responses of the wrapped service.
///
/// Type parameters:
/// - `S`: the inner service that is called on cache misses.
/// - `Req`: the request type the configured key extractor reads from.
/// - `K`: the cache key type produced by the key extractor.
/// - `Resp`: the type of the stored responses (the inner service's response
///   type when used as a [`Service`]).
///
/// The configuration and the store are both held behind `Arc`, so clones of a
/// `Cache` share the same cached entries.
pub struct Cache<S, Req, K, Resp> {
    /// The wrapped service, called when no cached entry is available.
    inner: S,
    /// Shared configuration (name, `max_size`, `ttl`, key extractor, event listeners).
    config: Arc<CacheConfig<Req, K>>,
    /// Shared entry store; every access goes through the mutex.
    store: Arc<Mutex<CacheStore<K, Resp>>>,
}
79
80impl<S, Req, K, Resp> Cache<S, Req, K, Resp>
81where
82 K: Hash + Eq,
83 Resp: Clone,
84{
85 pub fn new(inner: S, config: Arc<CacheConfig<Req, K>>) -> Self {
87 #[cfg(feature = "metrics")]
88 {
89 describe_counter!(
90 "cache_requests_total",
91 "Total number of cache requests (hits and misses)"
92 );
93 describe_counter!("cache_evictions_total", "Total number of cache evictions");
94 describe_gauge!("cache_size", "Current number of entries in the cache");
95 }
96
97 let store = Arc::new(Mutex::new(CacheStore::new(config.max_size, config.ttl)));
98 Self {
99 inner,
100 config,
101 store,
102 }
103 }
104}
105
106impl<S, Req, K, Resp> Clone for Cache<S, Req, K, Resp>
107where
108 S: Clone,
109{
110 fn clone(&self) -> Self {
111 Self {
112 inner: self.inner.clone(),
113 config: Arc::clone(&self.config),
114 store: Arc::clone(&self.store),
115 }
116 }
117}
118
119impl<S, Req, K> Service<Req> for Cache<S, Req, K, S::Response>
120where
121 S: Service<Req>,
122 S::Response: Clone + Send + 'static,
123 K: Hash + Eq + Clone + Send + 'static,
124 Req: Send + 'static,
125 S::Future: Send + 'static,
126{
127 type Response = S::Response;
128 type Error = CacheError<S::Error>;
129 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
130
131 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
132 self.inner.poll_ready(cx).map_err(CacheError::Inner)
133 }
134
135 fn call(&mut self, req: Req) -> Self::Future {
136 let key = (self.config.key_extractor)(&req);
137 let cache_name = self.config.name.clone();
138
139 let cached = {
141 let mut store = self.store.lock().unwrap();
142 store.get(&key)
143 };
144
145 if let Some(response) = cached {
146 #[cfg(feature = "metrics")]
148 {
149 counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "hit")
150 .increment(1);
151 }
152
153 #[cfg(feature = "tracing")]
154 debug!(cache = %cache_name, "Cache hit");
155
156 let event = CacheEvent::Hit {
157 pattern_name: cache_name,
158 timestamp: Instant::now(),
159 };
160 self.config.event_listeners.emit(&event);
161 return Box::pin(async move { Ok(response) });
162 }
163
164 #[cfg(feature = "metrics")]
166 {
167 counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "miss")
168 .increment(1);
169 }
170
171 #[cfg(feature = "tracing")]
172 debug!(cache = %cache_name, "Cache miss");
173
174 let miss_event = CacheEvent::Miss {
175 pattern_name: cache_name.clone(),
176 timestamp: Instant::now(),
177 };
178 self.config.event_listeners.emit(&miss_event);
179
180 let future = self.inner.call(req);
181 let store = Arc::clone(&self.store);
182 let config = Arc::clone(&self.config);
183
184 Box::pin(async move {
185 let response = future.await.map_err(CacheError::Inner)?;
186
187 let was_evicted = {
189 let mut store = store.lock().unwrap();
190 let was_full = store.len() >= config.max_size;
191 store.insert(key, response.clone());
192
193 #[cfg(feature = "metrics")]
195 {
196 let new_size = store.len();
197 gauge!("cache_size", "cache" => config.name.clone()).set(new_size as f64);
198 }
199
200 was_full
201 };
202
203 if was_evicted {
204 #[cfg(feature = "metrics")]
205 {
206 counter!("cache_evictions_total", "cache" => config.name.clone()).increment(1);
207 }
208
209 #[cfg(feature = "tracing")]
210 info!(cache = %config.name, "Cache eviction occurred");
211
212 let event = CacheEvent::Eviction {
213 pattern_name: config.name.clone(),
214 timestamp: Instant::now(),
215 };
216 config.event_listeners.emit(&event);
217 }
218
219 Ok(response)
220 })
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::Layer;
    use tower::ServiceExt;

    /// Waits until `service` is ready, then sends it `req` as an owned `String`.
    ///
    /// Readiness errors and call errors are both returned as `Err`.
    async fn request<S>(service: &mut S, req: &str) -> Result<S::Response, S::Error>
    where
        S: Service<String>,
    {
        service.ready().await?.call(req.to_string()).await
    }

    #[tokio::test]
    async fn cache_hit_returns_cached_response() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response1 = request(&mut service, "test").await.unwrap();
        assert_eq!(response1, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // The same key again must be served from the cache.
        let response2 = request(&mut service, "test").await.unwrap();
        assert_eq!(response2, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cache_miss_calls_inner_service() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response = request(&mut service, "test").await.unwrap();
        assert_eq!(response, "Response: test");
    }

    #[tokio::test]
    async fn different_keys_not_cached_together() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        request(&mut service, "test1").await.unwrap();
        request(&mut service, "test2").await.unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn ttl_expiration_causes_cache_miss() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .ttl(Duration::from_millis(50))
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        request(&mut service, "test").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Wait past the 50 ms TTL so the entry is expired.
        tokio::time::sleep(Duration::from_millis(100)).await;

        request(&mut service, "test").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn lru_eviction_removes_least_recently_used() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(2)
            .key_extractor(|req: &String| req.clone())
            .build();

        // Use one service for everything. `Cache::new` allocates a fresh store,
        // so a second `layer.layer(..)` would not see these entries.
        let mut service = layer.layer(service);

        request(&mut service, "key1").await.unwrap();
        request(&mut service, "key2").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        // Touch key1 so that key2 becomes the least recently used entry.
        request(&mut service, "key1").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        // Inserting key3 into the full cache must evict key2, not key1.
        request(&mut service, "key3").await.unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        request(&mut service, "key1").await.unwrap();
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "recently used key1 should still be cached"
        );

        request(&mut service, "key2").await.unwrap();
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            4,
            "least recently used key2 should have been evicted"
        );
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let hit_count = Arc::new(AtomicUsize::new(0));
        let miss_count = Arc::new(AtomicUsize::new(0));
        let eviction_count = Arc::new(AtomicUsize::new(0));

        let hc = Arc::clone(&hit_count);
        let mc = Arc::clone(&miss_count);
        let ec = Arc::clone(&eviction_count);

        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(1)
            .key_extractor(|req: &String| req.clone())
            .on_hit(move || {
                hc.fetch_add(1, Ordering::SeqCst);
            })
            .on_miss(move || {
                mc.fetch_add(1, Ordering::SeqCst);
            })
            .on_eviction(move || {
                ec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        request(&mut service, "test").await.unwrap();
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);
        assert_eq!(hit_count.load(Ordering::SeqCst), 0);

        request(&mut service, "test").await.unwrap();
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);

        // A new key in a full (max_size = 1) cache displaces the existing entry.
        request(&mut service, "other").await.unwrap();
        assert_eq!(eviction_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_not_cached() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(std::io::Error::other("error"))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        assert!(request(&mut service, "test").await.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // The failure was not cached, so the inner service is called again.
        assert!(request(&mut service, "test").await.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}