Skip to main content

tower_resilience_cache/
lib.rs

1//! Response caching middleware for Tower services.
2//!
3//! This crate provides a Tower middleware for caching service responses,
4//! reducing load on downstream services by storing and reusing responses
5//! for identical requests.
6//!
7//! # Features
8//!
9//! - **Multiple Eviction Policies**: LRU, LFU, and FIFO eviction strategies
10//! - **TTL Support**: Optional time-to-live for cache entries
11//! - **Event System**: Observability through cache events (Hit, Miss, Eviction)
12//! - **Flexible Key Extraction**: User-defined key extraction from requests
13//!
14//! # Examples
15//!
16//! ```
17//! use tower_resilience_cache::{CacheLayer, EvictionPolicy};
18//! use tower::ServiceBuilder;
19//! use std::time::Duration;
20//!
21//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
22//! // Create a cache layer with LFU eviction
23//! let cache_layer = CacheLayer::builder()
24//!     .max_size(100)
25//!     .ttl(Duration::from_secs(60))
26//!     .eviction_policy(EvictionPolicy::Lfu)  // or Lru (default), Fifo
27//!     .key_extractor(|req: &String| req.clone())
28//!     .on_hit(|| println!("Cache hit!"))
29//!     .on_miss(|| println!("Cache miss!"))
30//!     .build();
31//!
32//! // Apply to a service
33//! let service = ServiceBuilder::new()
34//!     .layer(cache_layer)
35//!     .service(tower::service_fn(|req: String| async move {
36//!         Ok::<_, std::io::Error>(format!("Response: {}", req))
37//!     }));
38//! # Ok(())
39//! # }
40//! ```
41
42mod config;
43mod error;
44mod events;
45mod eviction;
46mod layer;
47mod store;
48
49pub use config::{CacheConfig, CacheConfigBuilder, KeyExtractor};
50pub use error::CacheError;
51pub use events::CacheEvent;
52pub use eviction::EvictionPolicy;
53pub use layer::CacheLayer;
54
55use futures::future::BoxFuture;
56use std::hash::Hash;
57use std::sync::{Arc, Mutex};
58use std::task::{Context, Poll};
59use std::time::Instant;
60use store::CacheStore;
61use tower::Service;
62
63#[cfg(feature = "metrics")]
64use metrics::{counter, describe_counter, describe_gauge, gauge};
65
66#[cfg(feature = "tracing")]
67use tracing::{debug, info};
68
69/// A Tower [`Service`] that caches responses.
70///
71/// This service wraps an inner service and caches successful responses.
72/// When a request comes in, the cache checks if a valid cached response
73/// exists. If so, it returns the cached value immediately without calling
74/// the inner service.
75///
76/// Responses must implement `Clone` to be cacheable.
77pub struct Cache<S, Req, K, Resp> {
78    inner: S,
79    config: Arc<CacheConfig<Req, K>>,
80    store: Arc<Mutex<CacheStore<K, Resp>>>,
81}
82
83impl<S, Req, K, Resp> Cache<S, Req, K, Resp>
84where
85    K: Hash + Eq + Clone + Send + 'static,
86    Resp: Clone + Send + 'static,
87{
88    /// Creates a new `Cache` wrapping the given service.
89    pub fn new(inner: S, config: Arc<CacheConfig<Req, K>>) -> Self {
90        #[cfg(feature = "metrics")]
91        {
92            describe_counter!(
93                "cache_requests_total",
94                "Total number of cache requests (hits and misses)"
95            );
96            describe_counter!("cache_evictions_total", "Total number of cache evictions");
97            describe_gauge!("cache_size", "Current number of entries in the cache");
98        }
99
100        let store = Arc::new(Mutex::new(CacheStore::new(
101            config.max_size,
102            config.ttl,
103            config.eviction_policy,
104        )));
105        Self {
106            inner,
107            config,
108            store,
109        }
110    }
111}
112
113impl<S, Req, K, Resp> Clone for Cache<S, Req, K, Resp>
114where
115    S: Clone,
116{
117    fn clone(&self) -> Self {
118        Self {
119            inner: self.inner.clone(),
120            config: Arc::clone(&self.config),
121            store: Arc::clone(&self.store),
122        }
123    }
124}
125
126impl<S, Req, K> Service<Req> for Cache<S, Req, K, S::Response>
127where
128    S: Service<Req>,
129    S::Response: Clone + Send + 'static,
130    K: Hash + Eq + Clone + Send + 'static,
131    Req: Send + 'static,
132    S::Future: Send + 'static,
133{
134    type Response = S::Response;
135    type Error = CacheError<S::Error>;
136    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
137
138    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139        self.inner.poll_ready(cx).map_err(CacheError::Inner)
140    }
141
142    fn call(&mut self, req: Req) -> Self::Future {
143        let key = (self.config.key_extractor)(&req);
144        let cache_name = self.config.name.clone();
145
146        // Check cache first
147        let cached = {
148            let mut store = self.store.lock().unwrap();
149            store.get(&key)
150        };
151
152        if let Some(response) = cached {
153            // Cache hit
154            #[cfg(feature = "metrics")]
155            {
156                counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "hit")
157                    .increment(1);
158            }
159
160            #[cfg(feature = "tracing")]
161            debug!(cache = %cache_name, "Cache hit");
162
163            let event = CacheEvent::Hit {
164                pattern_name: cache_name,
165                timestamp: Instant::now(),
166            };
167            self.config.event_listeners.emit(&event);
168            return Box::pin(async move { Ok(response) });
169        }
170
171        // Cache miss
172        #[cfg(feature = "metrics")]
173        {
174            counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "miss")
175                .increment(1);
176        }
177
178        #[cfg(feature = "tracing")]
179        debug!(cache = %cache_name, "Cache miss");
180
181        let miss_event = CacheEvent::Miss {
182            pattern_name: cache_name.clone(),
183            timestamp: Instant::now(),
184        };
185        self.config.event_listeners.emit(&miss_event);
186
187        let future = self.inner.call(req);
188        let store = Arc::clone(&self.store);
189        let config = Arc::clone(&self.config);
190
191        Box::pin(async move {
192            let response = future.await.map_err(CacheError::Inner)?;
193
194            // Store successful response in cache
195            let was_evicted = {
196                let mut store = store.lock().unwrap();
197                let was_full = store.len() >= config.max_size;
198                store.insert(key, response.clone());
199
200                // Update cache size gauge
201                #[cfg(feature = "metrics")]
202                {
203                    let new_size = store.len();
204                    gauge!("cache_size", "cache" => config.name.clone()).set(new_size as f64);
205                }
206
207                was_full
208            };
209
210            if was_evicted {
211                #[cfg(feature = "metrics")]
212                {
213                    counter!("cache_evictions_total", "cache" => config.name.clone()).increment(1);
214                }
215
216                #[cfg(feature = "tracing")]
217                info!(cache = %config.name, "Cache eviction occurred");
218
219                let event = CacheEvent::Eviction {
220                    pattern_name: config.name.clone(),
221                    timestamp: Instant::now(),
222                };
223                config.event_listeners.emit(&event);
224            }
225
226            Ok(response)
227        })
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::Layer;
    use tower::ServiceExt;

    #[tokio::test]
    async fn cache_hit_returns_cached_response() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // First call - cache miss
        let response1 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response1, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Second call - cache hit
        let response2 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response2, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Not called again
    }

    #[tokio::test]
    async fn cache_miss_calls_inner_service() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response, "Response: test");
    }

    #[tokio::test]
    async fn different_keys_not_cached_together() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test1".to_string())
            .await
            .unwrap();
        service
            .ready()
            .await
            .unwrap()
            .call("test2".to_string())
            .await
            .unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn ttl_expiration_causes_cache_miss() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .ttl(Duration::from_millis(50))
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Wait for TTL to expire
        tokio::time::sleep(Duration::from_millis(100)).await;

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2); // Called again
    }

    #[tokio::test]
    async fn lru_eviction_removes_least_recently_used() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        // No eviction policy is set, so the default (LRU) applies.
        let layer = CacheLayer::builder()
            .max_size(2)
            .key_extractor(|req: &String| req.clone())
            .build();

        // Every lookup goes through this ONE service so they all share one
        // store. `Cache::new` allocates a new store, so a second
        // `layer.layer(..)` may start empty, which would turn every lookup into
        // a miss regardless of eviction.
        let mut service = layer.layer(service);

        // Fill cache with 2 items
        service
            .ready()
            .await
            .unwrap()
            .call("key1".to_string())
            .await
            .unwrap();
        service
            .ready()
            .await
            .unwrap()
            .call("key2".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        // Touch key1 (hit) so that key2 becomes the least recently used entry.
        // Under FIFO key1 would still be the oldest, so this distinguishes LRU.
        service
            .ready()
            .await
            .unwrap()
            .call("key1".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        // Add third item: the full cache must evict key2, not key1
        service
            .ready()
            .await
            .unwrap()
            .call("key3".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key1 was recently used and must still be cached (hit)
        service
            .ready()
            .await
            .unwrap()
            .call("key1".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key2 was evicted, so the inner service is called again (miss)
        service
            .ready()
            .await
            .unwrap()
            .call("key2".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let hit_count = Arc::new(AtomicUsize::new(0));
        let miss_count = Arc::new(AtomicUsize::new(0));
        let eviction_count = Arc::new(AtomicUsize::new(0));

        let hc = Arc::clone(&hit_count);
        let mc = Arc::clone(&miss_count);
        let ec = Arc::clone(&eviction_count);

        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(1)
            .key_extractor(|req: &String| req.clone())
            .on_hit(move || {
                hc.fetch_add(1, Ordering::SeqCst);
            })
            .on_miss(move || {
                mc.fetch_add(1, Ordering::SeqCst);
            })
            .on_eviction(move || {
                ec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        // First call - miss
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);
        assert_eq!(hit_count.load(Ordering::SeqCst), 0);

        // Second call - hit
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);

        // Third call with different key - eviction
        service
            .ready()
            .await
            .unwrap()
            .call("other".to_string())
            .await
            .unwrap();
        assert_eq!(eviction_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_not_cached() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(std::io::Error::other("error"))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // First call - error
        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Second call - should call inner again (error not cached)
        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}