Skip to main content

tower_resilience_cache/
shared_layer.rs

1//! Shared cache layer that maintains a single store across all layer() calls.
2
3use crate::events::CacheEvent;
4use crate::eviction::EvictionPolicy;
5use crate::store::CacheStore;
6use crate::{Cache, CacheConfig, KeyExtractor};
7use std::hash::Hash;
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10use tower::Layer;
11use tower_resilience_core::{EventListeners, FnListener};
12
13/// A Tower [`Layer`] that applies response caching with a shared store.
14///
15/// Unlike [`CacheLayer`](crate::CacheLayer), this layer shares a single cache store
16/// across all services created via [`layer()`](Layer::layer). This is useful when
17/// multiple service instances (e.g., per-session or per-request services) need to
18/// share the same cache.
19///
20/// # Type Parameters
21///
22/// - `Req`: The request type
23/// - `K`: The cache key type (extracted from requests)
24/// - `Resp`: The response type that will be cached
25///
26/// # Examples
27///
28/// ```
29/// use tower_resilience_cache::SharedCacheLayer;
30/// use tower::ServiceBuilder;
31/// use std::time::Duration;
32///
33/// # async fn example() {
34/// // Create a shared cache layer
35/// let cache_layer: SharedCacheLayer<String, String, String> = SharedCacheLayer::builder()
36///     .max_size(100)
37///     .ttl(Duration::from_secs(60))
38///     .key_extractor(|req: &String| req.clone())
39///     .build();
40///
41/// // Both services share the same cache
42/// let service1 = ServiceBuilder::new()
43///     .layer(cache_layer.clone())
44///     .service(my_service());
45///
46/// let service2 = ServiceBuilder::new()
47///     .layer(cache_layer)
48///     .service(my_service());
49/// # }
50/// # fn my_service() -> impl tower::Service<String, Response = String, Error = std::io::Error> {
51/// #     tower::service_fn(|req: String| async move { Ok::<_, std::io::Error>(req) })
52/// # }
53/// ```
54///
55/// # Creating from CacheLayer
56///
57/// You can also convert an existing [`CacheLayer`](crate::CacheLayer) configuration:
58///
59/// ```
60/// use tower_resilience_cache::CacheLayer;
61/// use std::time::Duration;
62///
63/// let shared_cache = CacheLayer::builder()
64///     .max_size(100)
65///     .ttl(Duration::from_secs(60))
66///     .key_extractor(|req: &String| req.clone())
67///     .build()
68///     .shared::<String>();  // Specify the response type
69/// ```
70#[derive(Clone)]
71pub struct SharedCacheLayer<Req, K, Resp> {
72    config: Arc<CacheConfig<Req, K>>,
73    store: Arc<Mutex<CacheStore<K, Resp>>>,
74}
75
76impl<Req, K, Resp> SharedCacheLayer<Req, K, Resp>
77where
78    K: Hash + Eq + Clone + Send + 'static,
79    Resp: Clone + Send + 'static,
80{
81    /// Creates a new `SharedCacheLayer` with the given configuration.
82    pub fn new(config: CacheConfig<Req, K>) -> Self {
83        let store = Arc::new(Mutex::new(CacheStore::new(
84            config.max_size,
85            config.ttl,
86            config.eviction_policy,
87        )));
88        Self {
89            config: Arc::new(config),
90            store,
91        }
92    }
93
94    /// Creates a new `SharedCacheLayer` from an existing config Arc.
95    ///
96    /// This is used by [`CacheLayer::shared()`](crate::CacheLayer::shared).
97    pub(crate) fn from_config(config: Arc<CacheConfig<Req, K>>) -> Self {
98        let store = Arc::new(Mutex::new(CacheStore::new(
99            config.max_size,
100            config.ttl,
101            config.eviction_policy,
102        )));
103        Self { config, store }
104    }
105
106    /// Creates a new builder for configuring a shared cache layer.
107    ///
108    /// # Examples
109    ///
110    /// ```
111    /// use tower_resilience_cache::SharedCacheLayer;
112    /// use std::time::Duration;
113    ///
114    /// let layer: SharedCacheLayer<String, String, String> = SharedCacheLayer::builder()
115    ///     .max_size(100)
116    ///     .ttl(Duration::from_secs(60))
117    ///     .key_extractor(|req: &String| req.clone())
118    ///     .build();
119    /// ```
120    pub fn builder() -> SharedCacheConfigBuilder<Req, K, Resp> {
121        SharedCacheConfigBuilder::new()
122    }
123}
124
125impl<S, Req, K, Resp> Layer<S> for SharedCacheLayer<Req, K, Resp>
126where
127    K: Hash + Eq + Clone + Send + 'static,
128    S: tower::Service<Req, Response = Resp>,
129    Resp: Clone + Send + 'static,
130{
131    type Service = Cache<S, Req, K, Resp>;
132
133    fn layer(&self, service: S) -> Self::Service {
134        Cache::with_store(service, Arc::clone(&self.config), Arc::clone(&self.store))
135    }
136}
137
138/// Builder for configuring and constructing a shared cache layer.
139pub struct SharedCacheConfigBuilder<Req, K, Resp> {
140    max_size: usize,
141    ttl: Option<Duration>,
142    eviction_policy: EvictionPolicy,
143    key_extractor: Option<KeyExtractor<Req, K>>,
144    event_listeners: EventListeners<CacheEvent>,
145    name: String,
146    _resp: std::marker::PhantomData<Resp>,
147}
148
149impl<Req, K, Resp> SharedCacheConfigBuilder<Req, K, Resp>
150where
151    K: Hash + Eq + Clone + Send + 'static,
152    Resp: Clone + Send + 'static,
153{
154    /// Creates a new builder with default values.
155    pub fn new() -> Self {
156        Self {
157            max_size: 100,
158            ttl: None,
159            eviction_policy: EvictionPolicy::default(),
160            key_extractor: None,
161            event_listeners: EventListeners::new(),
162            name: String::from("<unnamed>"),
163            _resp: std::marker::PhantomData,
164        }
165    }
166
167    /// Sets the maximum number of entries in the cache.
168    ///
169    /// Default: 100
170    pub fn max_size(mut self, size: usize) -> Self {
171        self.max_size = size;
172        self
173    }
174
175    /// Sets the time-to-live for cached entries.
176    ///
177    /// If set, entries will expire after the specified duration.
178    /// Default: None (no expiration)
179    pub fn ttl(mut self, ttl: Duration) -> Self {
180        self.ttl = Some(ttl);
181        self
182    }
183
184    /// Sets the eviction policy for the cache.
185    ///
186    /// Determines which entry to evict when the cache reaches capacity.
187    ///
188    /// # Options
189    ///
190    /// - `EvictionPolicy::Lru` - Least Recently Used (default)
191    /// - `EvictionPolicy::Lfu` - Least Frequently Used
192    /// - `EvictionPolicy::Fifo` - First In, First Out
193    ///
194    /// Default: `EvictionPolicy::Lru`
195    pub fn eviction_policy(mut self, policy: EvictionPolicy) -> Self {
196        self.eviction_policy = policy;
197        self
198    }
199
200    /// Sets the function that extracts a cache key from a request.
201    ///
202    /// This function must be provided before building.
203    pub fn key_extractor<F>(mut self, f: F) -> Self
204    where
205        F: Fn(&Req) -> K + Send + Sync + 'static,
206    {
207        self.key_extractor = Some(Arc::new(f));
208        self
209    }
210
211    /// Sets the name of this cache instance for observability.
212    ///
213    /// Default: `"<unnamed>"`
214    pub fn name(mut self, name: impl Into<String>) -> Self {
215        self.name = name.into();
216        self
217    }
218
219    /// Registers a callback when a cache hit occurs.
220    pub fn on_hit<F>(mut self, f: F) -> Self
221    where
222        F: Fn() + Send + Sync + 'static,
223    {
224        self.event_listeners.add(FnListener::new(move |event| {
225            if matches!(event, CacheEvent::Hit { .. }) {
226                f();
227            }
228        }));
229        self
230    }
231
232    /// Registers a callback when a cache miss occurs.
233    pub fn on_miss<F>(mut self, f: F) -> Self
234    where
235        F: Fn() + Send + Sync + 'static,
236    {
237        self.event_listeners.add(FnListener::new(move |event| {
238            if matches!(event, CacheEvent::Miss { .. }) {
239                f();
240            }
241        }));
242        self
243    }
244
245    /// Registers a callback when an entry is evicted from the cache.
246    pub fn on_eviction<F>(mut self, f: F) -> Self
247    where
248        F: Fn() + Send + Sync + 'static,
249    {
250        self.event_listeners.add(FnListener::new(move |event| {
251            if matches!(event, CacheEvent::Eviction { .. }) {
252                f();
253            }
254        }));
255        self
256    }
257
258    /// Builds the shared cache layer.
259    ///
260    /// # Panics
261    ///
262    /// Panics if `key_extractor` was not set.
263    pub fn build(self) -> SharedCacheLayer<Req, K, Resp> {
264        let key_extractor = self
265            .key_extractor
266            .expect("key_extractor must be set before building");
267
268        let config = CacheConfig {
269            max_size: self.max_size,
270            ttl: self.ttl,
271            eviction_policy: self.eviction_policy,
272            key_extractor,
273            event_listeners: self.event_listeners,
274            name: self.name,
275        };
276
277        SharedCacheLayer::new(config)
278    }
279}
280
281impl<Req, K, Resp> Default for SharedCacheConfigBuilder<Req, K, Resp>
282where
283    K: Hash + Eq + Clone + Send + 'static,
284    Resp: Clone + Send + 'static,
285{
286    fn default() -> Self {
287        Self::new()
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::service_fn;
    use tower::{Service, ServiceExt};

    /// Request type with a string id, used to exercise the builder with a
    /// non-`String` request.
    #[derive(Clone, Hash, Eq, PartialEq)]
    struct TestRequest {
        id: String,
    }

    /// Waits for `svc` to be ready, sends `key` as the request and returns the
    /// response, panicking if either step fails.
    async fn call_with<S>(svc: &mut S, key: &str) -> S::Response
    where
        S: Service<String>,
        S::Error: std::fmt::Debug,
    {
        let ready = svc.ready().await.unwrap();
        ready.call(key.to_string()).await.unwrap()
    }

    /// Building with only a key extractor succeeds.
    #[test]
    fn test_shared_builder_defaults() {
        let _layer: SharedCacheLayer<TestRequest, String, String> = SharedCacheLayer::builder()
            .key_extractor(|request: &TestRequest| request.id.clone())
            .build();
    }

    /// Building with size, TTL, key extractor and name all set succeeds.
    #[test]
    fn test_shared_builder_custom_values() {
        let one_minute = Duration::from_secs(60);
        let _layer: SharedCacheLayer<TestRequest, String, String> = SharedCacheLayer::builder()
            .max_size(500)
            .ttl(one_minute)
            .key_extractor(|request: &TestRequest| request.id.clone())
            .name("my-shared-cache")
            .build();
    }

    /// `build()` panics when no key extractor was supplied.
    #[test]
    #[should_panic(expected = "key_extractor must be set")]
    fn test_shared_builder_panics_without_key_extractor() {
        let _config: SharedCacheLayer<TestRequest, String, String> =
            SharedCacheLayer::builder().build();
    }

    /// Two services wrapped by clones of one `SharedCacheLayer` see each
    /// other's cached entries.
    #[tokio::test]
    async fn test_shared_cache_across_layer_calls() {
        let calls = Arc::new(AtomicUsize::new(0));

        // Two separate inner services that bump the same counter when invoked.
        let inner_a = {
            let counter = Arc::clone(&calls);
            service_fn(move |req: String| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, std::io::Error>(format!("Response: {}", req))
                }
            })
        };
        let inner_b = {
            let counter = Arc::clone(&calls);
            service_fn(move |req: String| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, std::io::Error>(format!("Response: {}", req))
                }
            })
        };

        let shared_layer: SharedCacheLayer<String, String, String> = SharedCacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut cached_a = shared_layer.clone().layer(inner_a);
        let mut cached_b = shared_layer.layer(inner_b);

        // First request goes through service A: a miss, so the inner service runs.
        let first = call_with(&mut cached_a, "test").await;
        assert_eq!(first, "Response: test");
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Same key through service B: served from the shared store.
        let second = call_with(&mut cached_b, "test").await;
        assert_eq!(second, "Response: test");
        // The inner services were still only invoked once in total.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Contrast case: services wrapped by a plain `CacheLayer` do not share
    /// entries, which is the problem `SharedCacheLayer` solves.
    #[tokio::test]
    async fn test_non_shared_cache_layer_creates_separate_stores() {
        use crate::CacheLayer;

        let calls = Arc::new(AtomicUsize::new(0));

        let inner_a = {
            let counter = Arc::clone(&calls);
            service_fn(move |req: String| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, std::io::Error>(format!("Response: {}", req))
                }
            })
        };
        let inner_b = {
            let counter = Arc::clone(&calls);
            service_fn(move |req: String| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, std::io::Error>(format!("Response: {}", req))
                }
            })
        };

        // Regular (non-shared) layer.
        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut cached_a = layer.clone().layer(inner_a);
        let mut cached_b = layer.layer(inner_b);

        // Miss on service A.
        call_with(&mut cached_a, "test").await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Same key on service B is also a miss because the stores are separate.
        call_with(&mut cached_b, "test").await;
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}