Skip to main content

tower_resilience_cache/
lib.rs

1//! Response caching middleware for Tower services.
2//!
3//! This crate provides a Tower middleware for caching service responses,
4//! reducing load on downstream services by storing and reusing responses
5//! for identical requests.
6//!
7//! # Features
8//!
9//! - **Multiple Eviction Policies**: LRU, LFU, and FIFO eviction strategies
10//! - **TTL Support**: Optional time-to-live for cache entries
11//! - **Event System**: Observability through cache events (Hit, Miss, Eviction)
12//! - **Flexible Key Extraction**: User-defined key extraction from requests
13//!
14//! # Examples
15//!
16//! ```
17//! use tower_resilience_cache::{CacheLayer, EvictionPolicy};
18//! use tower::ServiceBuilder;
19//! use std::time::Duration;
20//!
21//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
22//! // Create a cache layer with LFU eviction
23//! let cache_layer = CacheLayer::builder()
24//!     .max_size(100)
25//!     .ttl(Duration::from_secs(60))
26//!     .eviction_policy(EvictionPolicy::Lfu)  // or Lru (default), Fifo
27//!     .key_extractor(|req: &String| req.clone())
28//!     .on_hit(|| println!("Cache hit!"))
29//!     .on_miss(|| println!("Cache miss!"))
30//!     .build();
31//!
32//! // Apply to a service
33//! let service = ServiceBuilder::new()
34//!     .layer(cache_layer)
35//!     .service(tower::service_fn(|req: String| async move {
36//!         Ok::<_, std::io::Error>(format!("Response: {}", req))
37//!     }));
38//! # Ok(())
39//! # }
40//! ```
41
42mod config;
43mod error;
44mod events;
45mod eviction;
46mod layer;
47mod shared_layer;
48mod store;
49
50pub use config::{CacheConfig, CacheConfigBuilder, KeyExtractor};
51pub use error::CacheError;
52pub use events::CacheEvent;
53pub use eviction::EvictionPolicy;
54pub use layer::CacheLayer;
55pub use shared_layer::{SharedCacheConfigBuilder, SharedCacheLayer};
56
57use futures::future::BoxFuture;
58use std::hash::Hash;
59use std::sync::{Arc, Mutex};
60use std::task::{Context, Poll};
61use std::time::Instant;
62use store::CacheStore;
63use tower::Service;
64
65#[cfg(feature = "metrics")]
66use metrics::{counter, describe_counter, describe_gauge, gauge};
67
68#[cfg(feature = "tracing")]
69use tracing::{debug, info};
70
/// A Tower [`Service`] that caches responses.
///
/// This service wraps an inner service and caches successful responses.
/// When a request comes in, the cache checks if a valid cached response
/// exists. If so, it returns the cached value immediately without calling
/// the inner service.
///
/// Responses must implement `Clone` to be cacheable.
///
/// Type parameters: `S` is the wrapped service, `Req` the request type,
/// `K` the cache key produced by the configured key extractor, and `Resp`
/// the cached response type.
pub struct Cache<S, Req, K, Resp> {
    /// The wrapped service; only called on a cache miss.
    inner: S,
    /// Shared configuration: key extractor, size/TTL/eviction settings,
    /// event listeners and the cache name used in events and metrics.
    config: Arc<CacheConfig<Req, K>>,
    /// Entry storage behind a mutex. Clones of this service share it, and
    /// `with_store` lets several services share one store.
    store: Arc<Mutex<CacheStore<K, Resp>>>,
}
84
85impl<S, Req, K, Resp> Cache<S, Req, K, Resp>
86where
87    K: Hash + Eq + Clone + Send + 'static,
88    Resp: Clone + Send + 'static,
89{
90    /// Creates a new `Cache` wrapping the given service.
91    pub fn new(inner: S, config: Arc<CacheConfig<Req, K>>) -> Self {
92        #[cfg(feature = "metrics")]
93        {
94            describe_counter!(
95                "cache_requests_total",
96                "Total number of cache requests (hits and misses)"
97            );
98            describe_counter!("cache_evictions_total", "Total number of cache evictions");
99            describe_gauge!("cache_size", "Current number of entries in the cache");
100        }
101
102        let store = Arc::new(Mutex::new(CacheStore::new(
103            config.max_size,
104            config.ttl,
105            config.eviction_policy,
106        )));
107        Self {
108            inner,
109            config,
110            store,
111        }
112    }
113
114    /// Creates a new `Cache` wrapping the given service with a pre-existing store.
115    ///
116    /// This is used by [`SharedCacheLayer`] to share the same cache store across
117    /// multiple services.
118    pub(crate) fn with_store(
119        inner: S,
120        config: Arc<CacheConfig<Req, K>>,
121        store: Arc<Mutex<CacheStore<K, Resp>>>,
122    ) -> Self {
123        #[cfg(feature = "metrics")]
124        {
125            describe_counter!(
126                "cache_requests_total",
127                "Total number of cache requests (hits and misses)"
128            );
129            describe_counter!("cache_evictions_total", "Total number of cache evictions");
130            describe_gauge!("cache_size", "Current number of entries in the cache");
131        }
132
133        Self {
134            inner,
135            config,
136            store,
137        }
138    }
139}
140
141impl<S, Req, K, Resp> Clone for Cache<S, Req, K, Resp>
142where
143    S: Clone,
144{
145    fn clone(&self) -> Self {
146        Self {
147            inner: self.inner.clone(),
148            config: Arc::clone(&self.config),
149            store: Arc::clone(&self.store),
150        }
151    }
152}
153
154impl<S, Req, K> Service<Req> for Cache<S, Req, K, S::Response>
155where
156    S: Service<Req>,
157    S::Response: Clone + Send + 'static,
158    K: Hash + Eq + Clone + Send + 'static,
159    Req: Send + 'static,
160    S::Future: Send + 'static,
161{
162    type Response = S::Response;
163    type Error = CacheError<S::Error>;
164    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
165
166    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
167        self.inner.poll_ready(cx).map_err(CacheError::Inner)
168    }
169
170    fn call(&mut self, req: Req) -> Self::Future {
171        let key = (self.config.key_extractor)(&req);
172        let cache_name = self.config.name.clone();
173
174        // Check cache first
175        let cached = {
176            let mut store = self.store.lock().unwrap();
177            store.get(&key)
178        };
179
180        if let Some(response) = cached {
181            // Cache hit
182            #[cfg(feature = "metrics")]
183            {
184                counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "hit")
185                    .increment(1);
186            }
187
188            #[cfg(feature = "tracing")]
189            debug!(cache = %cache_name, "Cache hit");
190
191            let event = CacheEvent::Hit {
192                pattern_name: cache_name,
193                timestamp: Instant::now(),
194            };
195            self.config.event_listeners.emit(&event);
196            return Box::pin(async move { Ok(response) });
197        }
198
199        // Cache miss
200        #[cfg(feature = "metrics")]
201        {
202            counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "miss")
203                .increment(1);
204        }
205
206        #[cfg(feature = "tracing")]
207        debug!(cache = %cache_name, "Cache miss");
208
209        let miss_event = CacheEvent::Miss {
210            pattern_name: cache_name.clone(),
211            timestamp: Instant::now(),
212        };
213        self.config.event_listeners.emit(&miss_event);
214
215        let future = self.inner.call(req);
216        let store = Arc::clone(&self.store);
217        let config = Arc::clone(&self.config);
218
219        Box::pin(async move {
220            let response = future.await.map_err(CacheError::Inner)?;
221
222            // Store successful response in cache
223            let was_evicted = {
224                let mut store = store.lock().unwrap();
225                let was_full = store.len() >= config.max_size;
226                store.insert(key, response.clone());
227
228                // Update cache size gauge
229                #[cfg(feature = "metrics")]
230                {
231                    let new_size = store.len();
232                    gauge!("cache_size", "cache" => config.name.clone()).set(new_size as f64);
233                }
234
235                was_full
236            };
237
238            if was_evicted {
239                #[cfg(feature = "metrics")]
240                {
241                    counter!("cache_evictions_total", "cache" => config.name.clone()).increment(1);
242                }
243
244                #[cfg(feature = "tracing")]
245                info!(cache = %config.name, "Cache eviction occurred");
246
247                let event = CacheEvent::Eviction {
248                    pattern_name: config.name.clone(),
249                    timestamp: Instant::now(),
250                };
251                config.event_listeners.emit(&event);
252            }
253
254            Ok(response)
255        })
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::Layer;
    use tower::ServiceExt;

    #[tokio::test]
    async fn cache_hit_returns_cached_response() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // First call - cache miss
        let response1 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response1, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Second call - cache hit
        let response2 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response2, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Not called again
    }

    #[tokio::test]
    async fn cache_miss_calls_inner_service() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response, "Response: test");
    }

    #[tokio::test]
    async fn different_keys_not_cached_together() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test1".to_string())
            .await
            .unwrap();
        service
            .ready()
            .await
            .unwrap()
            .call("test2".to_string())
            .await
            .unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn ttl_expiration_causes_cache_miss() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .ttl(Duration::from_millis(50))
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Wait for TTL to expire
        tokio::time::sleep(Duration::from_millis(100)).await;

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2); // Called again
    }

    #[tokio::test]
    async fn lru_eviction_removes_least_recently_used() {
        // Eviction must be observed on the SAME service. A second
        // `layer.layer(..)` builds a new `Cache`, which may not share this
        // store (sharing is what `SharedCacheLayer` is for). A miss on such a
        // service would prove nothing about eviction.
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(2)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // Fill cache with 2 items
        service
            .ready()
            .await
            .unwrap()
            .call("key1".to_string())
            .await
            .unwrap();
        service
            .ready()
            .await
            .unwrap()
            .call("key2".to_string())
            .await
            .unwrap();

        // Add third item, should evict key1 (least recently used)
        service
            .ready()
            .await
            .unwrap()
            .call("key3".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key3 was just inserted, so it must still be cached (cache hit)
        service
            .ready()
            .await
            .unwrap()
            .call("key3".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key1 should have been evicted (cache miss)
        service
            .ready()
            .await
            .unwrap()
            .call("key1".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let hit_count = Arc::new(AtomicUsize::new(0));
        let miss_count = Arc::new(AtomicUsize::new(0));
        let eviction_count = Arc::new(AtomicUsize::new(0));

        let hc = Arc::clone(&hit_count);
        let mc = Arc::clone(&miss_count);
        let ec = Arc::clone(&eviction_count);

        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(1)
            .key_extractor(|req: &String| req.clone())
            .on_hit(move || {
                hc.fetch_add(1, Ordering::SeqCst);
            })
            .on_miss(move || {
                mc.fetch_add(1, Ordering::SeqCst);
            })
            .on_eviction(move || {
                ec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        // First call - miss
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);
        assert_eq!(hit_count.load(Ordering::SeqCst), 0);

        // Second call - hit
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);

        // Third call with different key - eviction
        service
            .ready()
            .await
            .unwrap()
            .call("other".to_string())
            .await
            .unwrap();
        assert_eq!(eviction_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_not_cached() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(std::io::Error::other("error"))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // First call - error
        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Second call - should call inner again (error not cached)
        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}