1use crate::events::CacheEvent;
4use crate::eviction::EvictionPolicy;
5use crate::store::CacheStore;
6use crate::{Cache, CacheConfig, KeyExtractor};
7use std::hash::Hash;
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10use tower::Layer;
11use tower_resilience_core::{EventListeners, FnListener};
12
13#[derive(Clone)]
71pub struct SharedCacheLayer<Req, K, Resp> {
72 config: Arc<CacheConfig<Req, K>>,
73 store: Arc<Mutex<CacheStore<K, Resp>>>,
74}
75
76impl<Req, K, Resp> SharedCacheLayer<Req, K, Resp>
77where
78 K: Hash + Eq + Clone + Send + 'static,
79 Resp: Clone + Send + 'static,
80{
81 pub fn new(config: CacheConfig<Req, K>) -> Self {
83 let store = Arc::new(Mutex::new(CacheStore::new(
84 config.max_size,
85 config.ttl,
86 config.eviction_policy,
87 )));
88 Self {
89 config: Arc::new(config),
90 store,
91 }
92 }
93
94 pub(crate) fn from_config(config: Arc<CacheConfig<Req, K>>) -> Self {
98 let store = Arc::new(Mutex::new(CacheStore::new(
99 config.max_size,
100 config.ttl,
101 config.eviction_policy,
102 )));
103 Self { config, store }
104 }
105
106 pub fn builder() -> SharedCacheConfigBuilder<Req, K, Resp> {
121 SharedCacheConfigBuilder::new()
122 }
123}
124
125impl<S, Req, K, Resp> Layer<S> for SharedCacheLayer<Req, K, Resp>
126where
127 K: Hash + Eq + Clone + Send + 'static,
128 S: tower::Service<Req, Response = Resp>,
129 Resp: Clone + Send + 'static,
130{
131 type Service = Cache<S, Req, K, Resp>;
132
133 fn layer(&self, service: S) -> Self::Service {
134 Cache::with_store(service, Arc::clone(&self.config), Arc::clone(&self.store))
135 }
136}
137
138pub struct SharedCacheConfigBuilder<Req, K, Resp> {
140 max_size: usize,
141 ttl: Option<Duration>,
142 eviction_policy: EvictionPolicy,
143 key_extractor: Option<KeyExtractor<Req, K>>,
144 event_listeners: EventListeners<CacheEvent>,
145 name: String,
146 _resp: std::marker::PhantomData<Resp>,
147}
148
149impl<Req, K, Resp> SharedCacheConfigBuilder<Req, K, Resp>
150where
151 K: Hash + Eq + Clone + Send + 'static,
152 Resp: Clone + Send + 'static,
153{
154 pub fn new() -> Self {
156 Self {
157 max_size: 100,
158 ttl: None,
159 eviction_policy: EvictionPolicy::default(),
160 key_extractor: None,
161 event_listeners: EventListeners::new(),
162 name: String::from("<unnamed>"),
163 _resp: std::marker::PhantomData,
164 }
165 }
166
167 pub fn max_size(mut self, size: usize) -> Self {
171 self.max_size = size;
172 self
173 }
174
175 pub fn ttl(mut self, ttl: Duration) -> Self {
180 self.ttl = Some(ttl);
181 self
182 }
183
184 pub fn eviction_policy(mut self, policy: EvictionPolicy) -> Self {
196 self.eviction_policy = policy;
197 self
198 }
199
200 pub fn key_extractor<F>(mut self, f: F) -> Self
204 where
205 F: Fn(&Req) -> K + Send + Sync + 'static,
206 {
207 self.key_extractor = Some(Arc::new(f));
208 self
209 }
210
211 pub fn name(mut self, name: impl Into<String>) -> Self {
215 self.name = name.into();
216 self
217 }
218
219 pub fn on_hit<F>(mut self, f: F) -> Self
221 where
222 F: Fn() + Send + Sync + 'static,
223 {
224 self.event_listeners.add(FnListener::new(move |event| {
225 if matches!(event, CacheEvent::Hit { .. }) {
226 f();
227 }
228 }));
229 self
230 }
231
232 pub fn on_miss<F>(mut self, f: F) -> Self
234 where
235 F: Fn() + Send + Sync + 'static,
236 {
237 self.event_listeners.add(FnListener::new(move |event| {
238 if matches!(event, CacheEvent::Miss { .. }) {
239 f();
240 }
241 }));
242 self
243 }
244
245 pub fn on_eviction<F>(mut self, f: F) -> Self
247 where
248 F: Fn() + Send + Sync + 'static,
249 {
250 self.event_listeners.add(FnListener::new(move |event| {
251 if matches!(event, CacheEvent::Eviction { .. }) {
252 f();
253 }
254 }));
255 self
256 }
257
258 pub fn build(self) -> SharedCacheLayer<Req, K, Resp> {
264 let key_extractor = self
265 .key_extractor
266 .expect("key_extractor must be set before building");
267
268 let config = CacheConfig {
269 max_size: self.max_size,
270 ttl: self.ttl,
271 eviction_policy: self.eviction_policy,
272 key_extractor,
273 event_listeners: self.event_listeners,
274 name: self.name,
275 };
276
277 SharedCacheLayer::new(config)
278 }
279}
280
281impl<Req, K, Resp> Default for SharedCacheConfigBuilder<Req, K, Resp>
282where
283 K: Hash + Eq + Clone + Send + 'static,
284 Resp: Clone + Send + 'static,
285{
286 fn default() -> Self {
287 Self::new()
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::service_fn;
    use tower::{Service, ServiceExt};

    /// Request type whose `id` serves as the cache key in these tests.
    #[derive(Clone, Hash, Eq, PartialEq)]
    struct TestRequest {
        id: String,
    }

    /// Supplying only the mandatory key extractor must be enough to build.
    #[test]
    fn test_shared_builder_defaults() {
        let _layer: SharedCacheLayer<TestRequest, String, String> = SharedCacheLayer::builder()
            .key_extractor(|req: &TestRequest| req.id.clone())
            .build();
    }

    /// Every setter can be combined before building.
    #[test]
    fn test_shared_builder_custom_values() {
        let _layer: SharedCacheLayer<TestRequest, String, String> = SharedCacheLayer::builder()
            .max_size(500)
            .ttl(Duration::from_secs(60))
            .key_extractor(|req: &TestRequest| req.id.clone())
            .name("my-shared-cache")
            .build();
    }

    /// Building without a key extractor is a programming error and panics.
    #[test]
    #[should_panic(expected = "key_extractor must be set")]
    fn test_shared_builder_panics_without_key_extractor() {
        let _layer: SharedCacheLayer<TestRequest, String, String> =
            SharedCacheLayer::builder().build();
    }

    /// Two services produced by one shared layer (one through a clone) use a single
    /// store: the second request for the same key never reaches an inner service.
    #[tokio::test]
    async fn test_shared_cache_across_layer_calls() {
        let inner_calls = Arc::new(AtomicUsize::new(0));
        let counter_a = Arc::clone(&inner_calls);
        let counter_b = Arc::clone(&inner_calls);

        let inner_a = service_fn(move |req: String| {
            let counter = Arc::clone(&counter_a);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let inner_b = service_fn(move |req: String| {
            let counter = Arc::clone(&counter_b);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let shared_layer: SharedCacheLayer<String, String, String> = SharedCacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut cached_a = shared_layer.clone().layer(inner_a);
        let mut cached_b = shared_layer.layer(inner_b);

        // First request for the key misses and is forwarded to `inner_a`.
        let ready_a = cached_a.ready().await.unwrap();
        let first = ready_a.call("test".to_string()).await.unwrap();
        assert_eq!(first, "Response: test");
        assert_eq!(inner_calls.load(Ordering::SeqCst), 1);

        // Same key through the other service is answered from the shared store.
        let ready_b = cached_b.ready().await.unwrap();
        let second = ready_b.call("test".to_string()).await.unwrap();
        assert_eq!(second, "Response: test");
        assert_eq!(inner_calls.load(Ordering::SeqCst), 1);
    }

    /// Contrast case: a plain `CacheLayer` gives every wrapped service its own store,
    /// so the same key reaches both inner services.
    #[tokio::test]
    async fn test_non_shared_cache_layer_creates_separate_stores() {
        use crate::CacheLayer;

        let inner_calls = Arc::new(AtomicUsize::new(0));
        let counter_a = Arc::clone(&inner_calls);
        let counter_b = Arc::clone(&inner_calls);

        let inner_a = service_fn(move |req: String| {
            let counter = Arc::clone(&counter_a);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let inner_b = service_fn(move |req: String| {
            let counter = Arc::clone(&counter_b);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = CacheLayer::builder()
            .max_size(10)
            .key_extractor(|req: &String| req.clone())
            .build();

        let mut cached_a = layer.clone().layer(inner_a);
        let mut cached_b = layer.layer(inner_b);

        let ready_a = cached_a.ready().await.unwrap();
        ready_a.call("test".to_string()).await.unwrap();
        assert_eq!(inner_calls.load(Ordering::SeqCst), 1);

        let ready_b = cached_b.ready().await.unwrap();
        ready_b.call("test".to_string()).await.unwrap();
        assert_eq!(inner_calls.load(Ordering::SeqCst), 2);
    }
}