tower_resilience_cache/
lib.rs

1//! Response caching middleware for Tower services.
2//!
//! This crate provides a Tower middleware for caching service responses,
//! reducing load on downstream services by storing and reusing responses
//! for requests that map to the same cache key (as computed by a
//! user-supplied key extractor).
6//!
7//! # Features
8//!
9//! - **LRU Eviction**: Least Recently Used eviction policy
10//! - **TTL Support**: Optional time-to-live for cache entries
11//! - **Event System**: Observability through cache events (Hit, Miss, Eviction)
12//! - **Flexible Key Extraction**: User-defined key extraction from requests
13//!
14//! # Examples
15//!
16//! ```
17//! use tower_resilience_cache::CacheLayer;
18//! use tower::ServiceBuilder;
19//! use std::time::Duration;
20//!
21//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
22//! // Create a cache layer
23//! let cache_layer = CacheLayer::builder()
24//!     .max_size(100)
25//!     .ttl(Duration::from_secs(60))
26//!     .key_extractor(|req: &String| req.clone())
27//!     .on_hit(|| println!("Cache hit!"))
28//!     .on_miss(|| println!("Cache miss!"))
29//!     .build();
30//!
31//! // Apply to a service
32//! let service = ServiceBuilder::new()
33//!     .layer(cache_layer)
34//!     .service(tower::service_fn(|req: String| async move {
35//!         Ok::<_, std::io::Error>(format!("Response: {}", req))
36//!     }));
37//! # Ok(())
38//! # }
39//! ```
40
41mod config;
42mod error;
43mod events;
44mod layer;
45mod store;
46
47pub use config::{CacheConfig, CacheConfigBuilder, KeyExtractor};
48pub use error::CacheError;
49pub use events::CacheEvent;
50pub use layer::CacheLayer;
51
52use futures::future::BoxFuture;
53use std::hash::Hash;
54use std::sync::{Arc, Mutex};
55use std::task::{Context, Poll};
56use std::time::Instant;
57use store::CacheStore;
58use tower::Service;
59
60#[cfg(feature = "metrics")]
61use metrics::{counter, describe_counter, describe_gauge, gauge};
62
63#[cfg(feature = "tracing")]
64use tracing::{debug, info};
65
66/// A Tower [`Service`] that caches responses.
67///
68/// This service wraps an inner service and caches successful responses.
69/// When a request comes in, the cache checks if a valid cached response
70/// exists. If so, it returns the cached value immediately without calling
71/// the inner service.
72///
73/// Responses must implement `Clone` to be cacheable.
74pub struct Cache<S, Req, K, Resp> {
75    inner: S,
76    config: Arc<CacheConfig<Req, K>>,
77    store: Arc<Mutex<CacheStore<K, Resp>>>,
78}
79
80impl<S, Req, K, Resp> Cache<S, Req, K, Resp>
81where
82    K: Hash + Eq,
83    Resp: Clone,
84{
85    /// Creates a new `Cache` wrapping the given service.
86    pub fn new(inner: S, config: Arc<CacheConfig<Req, K>>) -> Self {
87        #[cfg(feature = "metrics")]
88        {
89            describe_counter!(
90                "cache_requests_total",
91                "Total number of cache requests (hits and misses)"
92            );
93            describe_counter!("cache_evictions_total", "Total number of cache evictions");
94            describe_gauge!("cache_size", "Current number of entries in the cache");
95        }
96
97        let store = Arc::new(Mutex::new(CacheStore::new(config.max_size, config.ttl)));
98        Self {
99            inner,
100            config,
101            store,
102        }
103    }
104}
105
106impl<S, Req, K, Resp> Clone for Cache<S, Req, K, Resp>
107where
108    S: Clone,
109{
110    fn clone(&self) -> Self {
111        Self {
112            inner: self.inner.clone(),
113            config: Arc::clone(&self.config),
114            store: Arc::clone(&self.store),
115        }
116    }
117}
118
119impl<S, Req, K> Service<Req> for Cache<S, Req, K, S::Response>
120where
121    S: Service<Req>,
122    S::Response: Clone + Send + 'static,
123    K: Hash + Eq + Clone + Send + 'static,
124    Req: Send + 'static,
125    S::Future: Send + 'static,
126{
127    type Response = S::Response;
128    type Error = CacheError<S::Error>;
129    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
130
131    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
132        self.inner.poll_ready(cx).map_err(CacheError::Inner)
133    }
134
135    fn call(&mut self, req: Req) -> Self::Future {
136        let key = (self.config.key_extractor)(&req);
137        let cache_name = self.config.name.clone();
138
139        // Check cache first
140        let cached = {
141            let mut store = self.store.lock().unwrap();
142            store.get(&key)
143        };
144
145        if let Some(response) = cached {
146            // Cache hit
147            #[cfg(feature = "metrics")]
148            {
149                counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "hit")
150                    .increment(1);
151            }
152
153            #[cfg(feature = "tracing")]
154            debug!(cache = %cache_name, "Cache hit");
155
156            let event = CacheEvent::Hit {
157                pattern_name: cache_name,
158                timestamp: Instant::now(),
159            };
160            self.config.event_listeners.emit(&event);
161            return Box::pin(async move { Ok(response) });
162        }
163
164        // Cache miss
165        #[cfg(feature = "metrics")]
166        {
167            counter!("cache_requests_total", "cache" => cache_name.clone(), "result" => "miss")
168                .increment(1);
169        }
170
171        #[cfg(feature = "tracing")]
172        debug!(cache = %cache_name, "Cache miss");
173
174        let miss_event = CacheEvent::Miss {
175            pattern_name: cache_name.clone(),
176            timestamp: Instant::now(),
177        };
178        self.config.event_listeners.emit(&miss_event);
179
180        let future = self.inner.call(req);
181        let store = Arc::clone(&self.store);
182        let config = Arc::clone(&self.config);
183
184        Box::pin(async move {
185            let response = future.await.map_err(CacheError::Inner)?;
186
187            // Store successful response in cache
188            let was_evicted = {
189                let mut store = store.lock().unwrap();
190                let was_full = store.len() >= config.max_size;
191                store.insert(key, response.clone());
192
193                // Update cache size gauge
194                #[cfg(feature = "metrics")]
195                {
196                    let new_size = store.len();
197                    gauge!("cache_size", "cache" => config.name.clone()).set(new_size as f64);
198                }
199
200                was_full
201            };
202
203            if was_evicted {
204                #[cfg(feature = "metrics")]
205                {
206                    counter!("cache_evictions_total", "cache" => config.name.clone()).increment(1);
207                }
208
209                #[cfg(feature = "tracing")]
210                info!(cache = %config.name, "Cache eviction occurred");
211
212                let event = CacheEvent::Eviction {
213                    pattern_name: config.name.clone(),
214                    timestamp: Instant::now(),
215                };
216                config.event_listeners.emit(&event);
217            }
218
219            Ok(response)
220        })
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::Layer;
    use tower::ServiceExt;

    #[tokio::test]
    async fn cache_hit_returns_cached_response() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // First call - cache miss
        let response1 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response1, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Second call - cache hit
        let response2 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response2, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Not called again
    }

    #[tokio::test]
    async fn cache_miss_calls_inner_service() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response, "Response: test");
    }

    #[tokio::test]
    async fn different_keys_not_cached_together() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test1".to_string())
            .await
            .unwrap();
        service
            .ready()
            .await
            .unwrap()
            .call("test2".to_string())
            .await
            .unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn ttl_expiration_causes_cache_miss() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .ttl(Duration::from_millis(50))
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Wait for TTL to expire
        tokio::time::sleep(Duration::from_millis(100)).await;

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2); // Called again
    }

    #[tokio::test]
    async fn lru_eviction_removes_least_recently_used() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(2)
            .key_extractor(|req: &String| req.clone())
            .build();

        // `Cache::new` allocates a fresh store for every wrapped service, so
        // eviction must be observed on the same instance that filled the cache.
        let mut service = layer.layer(service);

        // Fill the cache with 2 items, then add a third, which should evict
        // key1 (the least recently used entry).
        for key in ["key1", "key2", "key3"] {
            service
                .ready()
                .await
                .unwrap()
                .call(key.to_string())
                .await
                .unwrap();
        }
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key2 and key3 are still cached: no further inner calls.
        for key in ["key2", "key3"] {
            service
                .ready()
                .await
                .unwrap()
                .call(key.to_string())
                .await
                .unwrap();
        }
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key1 was evicted: the inner service is called again.
        service
            .ready()
            .await
            .unwrap()
            .call("key1".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let hit_count = Arc::new(AtomicUsize::new(0));
        let miss_count = Arc::new(AtomicUsize::new(0));
        let eviction_count = Arc::new(AtomicUsize::new(0));

        let hc = Arc::clone(&hit_count);
        let mc = Arc::clone(&miss_count);
        let ec = Arc::clone(&eviction_count);

        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(1)
            .key_extractor(|req: &String| req.clone())
            .on_hit(move || {
                hc.fetch_add(1, Ordering::SeqCst);
            })
            .on_miss(move || {
                mc.fetch_add(1, Ordering::SeqCst);
            })
            .on_eviction(move || {
                ec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        // First call - miss
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);
        assert_eq!(hit_count.load(Ordering::SeqCst), 0);

        // Second call - hit
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);

        // Third call with different key - eviction
        service
            .ready()
            .await
            .unwrap()
            .call("other".to_string())
            .await
            .unwrap();
        assert_eq!(eviction_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_not_cached() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(std::io::Error::other("error"))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // First call - error
        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Second call - should call inner again (error not cached)
        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}