tower_resilience_cache/
lib.rs

1//! Response caching middleware for Tower services.
2//!
3//! This crate provides a Tower middleware for caching service responses,
4//! reducing load on downstream services by storing and reusing responses
5//! for identical requests.
6//!
7//! # Features
8//!
9//! - **LRU Eviction**: Least Recently Used eviction policy
10//! - **TTL Support**: Optional time-to-live for cache entries
11//! - **Event System**: Observability through cache events (Hit, Miss, Eviction)
12//! - **Flexible Key Extraction**: User-defined key extraction from requests
13//!
14//! # Examples
15//!
16//! ```
17//! use tower_resilience_cache::CacheLayer;
18//! use tower::ServiceBuilder;
19//! use std::time::Duration;
20//!
21//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
22//! // Create a cache layer
23//! let cache_layer = CacheLayer::builder()
24//!     .max_size(100)
25//!     .ttl(Duration::from_secs(60))
26//!     .key_extractor(|req: &String| req.clone())
27//!     .on_hit(|| println!("Cache hit!"))
28//!     .on_miss(|| println!("Cache miss!"))
29//!     .build();
30//!
31//! // Apply to a service
32//! let service = ServiceBuilder::new()
33//!     .layer(cache_layer)
34//!     .service(tower::service_fn(|req: String| async move {
35//!         Ok::<_, std::io::Error>(format!("Response: {}", req))
36//!     }));
37//! # Ok(())
38//! # }
39//! ```
40
41mod config;
42mod error;
43mod events;
44mod layer;
45mod store;
46
47pub use config::{CacheConfig, CacheConfigBuilder, KeyExtractor};
48pub use error::CacheError;
49pub use events::CacheEvent;
50pub use layer::CacheLayer;
51
52use futures::future::BoxFuture;
53use std::hash::Hash;
54use std::sync::{Arc, Mutex};
55use std::task::{Context, Poll};
56use std::time::Instant;
57use store::CacheStore;
58use tower::Service;
59
60/// A Tower [`Service`] that caches responses.
61///
62/// This service wraps an inner service and caches successful responses.
63/// When a request comes in, the cache checks if a valid cached response
64/// exists. If so, it returns the cached value immediately without calling
65/// the inner service.
66///
67/// Responses must implement `Clone` to be cacheable.
68pub struct Cache<S, Req, K, Resp> {
69    inner: S,
70    config: Arc<CacheConfig<Req, K>>,
71    store: Arc<Mutex<CacheStore<K, Resp>>>,
72}
73
74impl<S, Req, K, Resp> Cache<S, Req, K, Resp>
75where
76    K: Hash + Eq,
77    Resp: Clone,
78{
79    /// Creates a new `Cache` wrapping the given service.
80    pub fn new(inner: S, config: Arc<CacheConfig<Req, K>>) -> Self {
81        let store = Arc::new(Mutex::new(CacheStore::new(config.max_size, config.ttl)));
82        Self {
83            inner,
84            config,
85            store,
86        }
87    }
88}
89
90impl<S, Req, K, Resp> Clone for Cache<S, Req, K, Resp>
91where
92    S: Clone,
93{
94    fn clone(&self) -> Self {
95        Self {
96            inner: self.inner.clone(),
97            config: Arc::clone(&self.config),
98            store: Arc::clone(&self.store),
99        }
100    }
101}
102
103impl<S, Req, K> Service<Req> for Cache<S, Req, K, S::Response>
104where
105    S: Service<Req>,
106    S::Response: Clone + Send + 'static,
107    K: Hash + Eq + Clone + Send + 'static,
108    Req: Send + 'static,
109    S::Future: Send + 'static,
110{
111    type Response = S::Response;
112    type Error = CacheError<S::Error>;
113    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
114
115    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116        self.inner.poll_ready(cx).map_err(CacheError::Inner)
117    }
118
119    fn call(&mut self, req: Req) -> Self::Future {
120        let key = (self.config.key_extractor)(&req);
121
122        // Check cache first
123        let cached = {
124            let mut store = self.store.lock().unwrap();
125            store.get(&key)
126        };
127
128        if let Some(response) = cached {
129            // Cache hit
130            let event = CacheEvent::Hit {
131                pattern_name: self.config.name.clone(),
132                timestamp: Instant::now(),
133            };
134            self.config.event_listeners.emit(&event);
135            return Box::pin(async move { Ok(response) });
136        }
137
138        // Cache miss
139        let miss_event = CacheEvent::Miss {
140            pattern_name: self.config.name.clone(),
141            timestamp: Instant::now(),
142        };
143        self.config.event_listeners.emit(&miss_event);
144
145        let future = self.inner.call(req);
146        let store = Arc::clone(&self.store);
147        let config = Arc::clone(&self.config);
148
149        Box::pin(async move {
150            let response = future.await.map_err(CacheError::Inner)?;
151
152            // Store successful response in cache
153            let was_evicted = {
154                let mut store = store.lock().unwrap();
155                let was_full = store.len() >= config.max_size;
156                store.insert(key, response.clone());
157                was_full
158            };
159
160            if was_evicted {
161                let event = CacheEvent::Eviction {
162                    pattern_name: config.name.clone(),
163                    timestamp: Instant::now(),
164                };
165                config.event_listeners.emit(&event);
166            }
167
168            Ok(response)
169        })
170    }
171}
172
#[cfg(test)]
mod tests {
    //! Behavioural tests for [`Cache`] driven through [`CacheLayer`].

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::Layer;
    use tower::ServiceExt;

    #[tokio::test]
    async fn cache_hit_returns_cached_response() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // First call - cache miss
        let response1 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response1, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Second call - cache hit
        let response2 = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response2, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Not called again
    }

    #[tokio::test]
    async fn cache_miss_calls_inner_service() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(response, "Response: test");
    }

    #[tokio::test]
    async fn different_keys_not_cached_together() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test1".to_string())
            .await
            .unwrap();
        service
            .ready()
            .await
            .unwrap()
            .call("test2".to_string())
            .await
            .unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn ttl_expiration_causes_cache_miss() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .ttl(Duration::from_millis(50))
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Wait for TTL to expire
        tokio::time::sleep(Duration::from_millis(100)).await;

        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 2); // Called again
    }

    #[tokio::test]
    async fn lru_eviction_removes_least_recently_used() {
        // Eviction must be observed through the SAME wrapped service.
        // `layer.layer(..)` builds a `Cache` with a brand-new, empty store, so
        // probing through a second wrapped service would miss every key and
        // prove nothing about eviction.
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(2)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // Fill cache with 2 items, then add a third which should evict key1.
        for key in ["key1", "key2", "key3"] {
            service
                .ready()
                .await
                .unwrap()
                .call(key.to_string())
                .await
                .unwrap();
        }
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key3 and key2 survived: both are served without calling the inner service.
        for key in ["key3", "key2"] {
            service
                .ready()
                .await
                .unwrap()
                .call(key.to_string())
                .await
                .unwrap();
        }
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // key1 was evicted: the inner service is called again.
        service
            .ready()
            .await
            .unwrap()
            .call("key1".to_string())
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let hit_count = Arc::new(AtomicUsize::new(0));
        let miss_count = Arc::new(AtomicUsize::new(0));
        let eviction_count = Arc::new(AtomicUsize::new(0));

        let hc = Arc::clone(&hit_count);
        let mc = Arc::clone(&miss_count);
        let ec = Arc::clone(&eviction_count);

        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = CacheLayer::builder()
            .max_size(1)
            .key_extractor(|req: &String| req.clone())
            .on_hit(move || {
                hc.fetch_add(1, Ordering::SeqCst);
            })
            .on_miss(move || {
                mc.fetch_add(1, Ordering::SeqCst);
            })
            .on_eviction(move || {
                ec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        // First call - miss
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);
        assert_eq!(hit_count.load(Ordering::SeqCst), 0);

        // Second call - hit
        service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();
        assert_eq!(hit_count.load(Ordering::SeqCst), 1);
        assert_eq!(miss_count.load(Ordering::SeqCst), 1);

        // Third call with different key - eviction
        service
            .ready()
            .await
            .unwrap()
            .call("other".to_string())
            .await
            .unwrap();
        assert_eq!(eviction_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn errors_not_cached() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(std::io::Error::other("error"))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut service = layer.layer(service);

        // First call - error
        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Second call - should call inner again (error not cached)
        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}