// tower_http_cache/backend/multi_tier.rs

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use dashmap::DashMap;

use super::{CacheBackend, CacheEntry, CacheRead};
use crate::error::CacheError;
/// Policy deciding when an entry observed in L2 is copied ("promoted")
/// into the faster L1 tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromotionStrategy {
    /// Promote once a key has accumulated `threshold` L2 hits.
    HitCount { threshold: u64 },

    /// Promote based on hit rate.
    ///
    /// NOTE(review): `should_promote` keeps no timestamps, so this variant
    /// is currently approximated by a plain hit count — confirm intent.
    HitRate { threshold_per_minute: u64 },
}
25
26impl Default for PromotionStrategy {
27 fn default() -> Self {
28 Self::HitCount { threshold: 3 }
29 }
30}
31
/// Aggregate per-tier hit/miss counters.
///
/// NOTE(review): this struct is stored as `Arc<TierStats>` with plain `u64`
/// fields, and no code path in this file ever increments them — so
/// `MultiTierBackend::stats()` appears to always return zeros. If these are
/// meant to be live counters they would need atomics; TODO confirm.
#[derive(Debug, Clone, Default)]
pub struct TierStats {
    // Requests served from the L1 tier.
    pub l1_hits: u64,
    // Requests served from the L2 tier.
    pub l2_hits: u64,
    // Requests that missed both tiers.
    pub misses: u64,
    // Entries promoted from L2 into L1.
    pub promotions: u64,
}
40
/// Tuning knobs for [`MultiTierBackend`].
#[derive(Debug, Clone)]
pub struct MultiTierConfig {
    /// When to copy an entry from L2 into L1.
    pub promotion_strategy: PromotionStrategy,

    /// If `true`, `set` also writes to L1 (subject to the size cap below).
    pub write_through: bool,

    /// Largest entry body (bytes) admitted to L1; `None` disables the cap.
    pub max_l1_entry_size: Option<usize>,
}
54
55impl Default for MultiTierConfig {
56 fn default() -> Self {
57 Self {
58 promotion_strategy: PromotionStrategy::default(),
59 write_through: true,
60 max_l1_entry_size: Some(256 * 1024), }
62 }
63}
64
/// Per-key promotion bookkeeping: a lock-free count of hits served from L2.
struct KeyStats {
    l2_hits: AtomicU64,
}

impl KeyStats {
    /// Creates a counter starting at zero.
    fn new() -> Self {
        KeyStats {
            l2_hits: AtomicU64::new(0),
        }
    }

    /// Increments the L2 hit counter and returns the updated total.
    fn record_hit(&self) -> u64 {
        let previous = self.l2_hits.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }

    /// Clears the counter (used after a successful promotion).
    fn reset(&self) {
        self.l2_hits.store(0, Ordering::Relaxed);
    }

    /// Current number of recorded L2 hits.
    fn hits(&self) -> u64 {
        self.l2_hits.load(Ordering::Relaxed)
    }
}
89
/// Two-tier cache backend: a fast, small L1 in front of a larger, slower
/// L2, with hit-count-driven promotion of hot L2 entries into L1.
#[derive(Clone)]
pub struct MultiTierBackend<L1, L2> {
    // Fast tier, consulted first on every read.
    l1: L1,
    // Large tier, consulted on L1 miss; source of truth for writes.
    l2: L2,
    config: MultiTierConfig,
    // Per-key L2 hit counters driving promotion decisions.
    key_stats: Arc<DashMap<String, Arc<KeyStats>>>,
    // NOTE(review): plain u64 fields behind Arc — never mutated in this file.
    tier_stats: Arc<TierStats>,
}
102
103impl<L1, L2> MultiTierBackend<L1, L2>
104where
105 L1: CacheBackend,
106 L2: CacheBackend,
107{
108 pub fn new(l1: L1, l2: L2) -> Self {
110 Self {
111 l1,
112 l2,
113 config: MultiTierConfig::default(),
114 key_stats: Arc::new(DashMap::new()),
115 tier_stats: Arc::new(TierStats::default()),
116 }
117 }
118
119 pub fn builder() -> MultiTierBuilder<L1, L2> {
121 MultiTierBuilder::new()
122 }
123
124 pub fn l1(&self) -> &L1 {
126 &self.l1
127 }
128
129 pub fn l2(&self) -> &L2 {
131 &self.l2
132 }
133
134 pub fn stats(&self) -> &TierStats {
136 &self.tier_stats
137 }
138
139 fn should_promote(&self, key: &str) -> bool {
141 let stats = self
142 .key_stats
143 .entry(key.to_string())
144 .or_insert_with(|| Arc::new(KeyStats::new()));
145
146 match self.config.promotion_strategy {
147 PromotionStrategy::HitCount { threshold } => stats.hits() >= threshold,
148 PromotionStrategy::HitRate {
149 threshold_per_minute: _,
150 } => {
151 stats.hits() >= 3
154 }
155 }
156 }
157
158 fn record_hit(&self, key: &str) -> u64 {
160 self.key_stats
161 .entry(key.to_string())
162 .or_insert_with(|| Arc::new(KeyStats::new()))
163 .record_hit()
164 }
165
166 #[allow(dead_code)]
168 async fn promote(
169 &self,
170 key: &str,
171 entry: CacheEntry,
172 ttl: Duration,
173 stale_for: Duration,
174 ) -> Result<(), CacheError> {
175 self.l1.set(key.to_string(), entry, ttl, stale_for).await?;
177
178 if let Some(stats) = self.key_stats.get(key) {
180 stats.reset();
181 }
182
183 Ok(())
184 }
185}
186
#[async_trait]
impl<L1, L2> CacheBackend for MultiTierBackend<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// Reads `key` from L1 first, falling back to L2.
    ///
    /// On an L2 hit the per-key counter is bumped and, once the promotion
    /// strategy is satisfied and the entry fits `max_l1_entry_size`, the
    /// entry is copied into L1 on a detached background task. The caller
    /// always gets the L2 read immediately; promotion failures are ignored.
    async fn get(&self, key: &str) -> Result<Option<CacheRead>, CacheError> {
        if let Some(entry) = self.l1.get(key).await? {
            #[cfg(feature = "metrics")]
            metrics::counter!("tower_http_cache.tier.l1_hit").increment(1);
            return Ok(Some(entry));
        }

        if let Some(read) = self.l2.get(key).await? {
            #[cfg(feature = "metrics")]
            metrics::counter!("tower_http_cache.tier.l2_hit").increment(1);

            // Count this L2 hit before consulting the promotion strategy.
            self.record_hit(key);

            if self.should_promote(key) {
                let entry_size = read.entry.body.len();

                // Oversized entries stay in L2 only.
                let should_promote_l1 = if let Some(max_size) = self.config.max_l1_entry_size {
                    entry_size <= max_size
                } else {
                    true
                };

                if should_promote_l1 {
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.promoted").increment(1);

                    // Re-derive a relative TTL from the read's absolute
                    // expiry. NOTE(review): if `expires_at` is absent or
                    // already in the past, the entry is promoted with a
                    // fresh 60s TTL — confirm L2 never serves reads past
                    // their stale window.
                    let ttl = if let Some(expires_at) = read.expires_at {
                        expires_at
                            .duration_since(std::time::SystemTime::now())
                            .unwrap_or(Duration::from_secs(60))
                    } else {
                        Duration::from_secs(60)
                    };

                    // Stale window = stale_until - expires_at; zero when
                    // either bound is absent (or the subtraction underflows).
                    let stale_for = if let (Some(stale_until), Some(expires_at)) =
                        (read.stale_until, read.expires_at)
                    {
                        stale_until.duration_since(expires_at).unwrap_or_default()
                    } else {
                        Duration::ZERO
                    };

                    let entry = read.entry.clone();
                    let key = key.to_string();
                    let l1 = self.l1.clone();
                    let key_stats = self.key_stats.clone();

                    // Promote off the request path; the write result is
                    // dropped and the hit counter reset optimistically.
                    tokio::spawn(async move {
                        let _ = l1.set(key.clone(), entry, ttl, stale_for).await;
                        if let Some(stats) = key_stats.get(&key) {
                            stats.reset();
                        }
                    });
                } else {
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.promotion_skipped_large").increment(1);

                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        key = %key,
                        size = entry_size,
                        max_l1_size = ?self.config.max_l1_entry_size,
                        "skipping promotion for large entry"
                    );
                }
            }

            return Ok(Some(read));
        }

        Ok(None)
    }

    /// Stores an entry, treating L2 as the source of truth.
    ///
    /// The L2 write happens first and its error aborts the call. When
    /// `write_through` is enabled the entry is also written to L1 — unless
    /// it exceeds `max_l1_entry_size` — and that L1 write is best-effort:
    /// its result is deliberately discarded.
    async fn set(
        &self,
        key: String,
        entry: CacheEntry,
        ttl: Duration,
        stale_for: Duration,
    ) -> Result<(), CacheError> {
        let entry_size = entry.body.len();

        // L2 first: it is authoritative, so a failure here propagates.
        self.l2
            .set(key.clone(), entry.clone(), ttl, stale_for)
            .await?;

        if self.config.write_through {
            let should_write_l1 = if let Some(max_size) = self.config.max_l1_entry_size {
                if entry_size <= max_size {
                    true
                } else {
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.l1_skipped_large").increment(1);

                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        key = %key,
                        size = entry_size,
                        max_l1_size = max_size,
                        "skipping L1 write for large entry"
                    );

                    false
                }
            } else {
                true
            };

            if should_write_l1 {
                // Best-effort: an L1 failure must not fail the set.
                let _ = self.l1.set(key.clone(), entry, ttl, stale_for).await;
            }
        }

        Ok(())
    }

    /// Removes `key` from both tiers and drops its promotion counter.
    ///
    /// Both invalidations are always attempted; `and` surfaces an L1 error
    /// first, otherwise the L2 result.
    async fn invalidate(&self, key: &str) -> Result<(), CacheError> {
        let l1_result = self.l1.invalidate(key).await;
        let l2_result = self.l2.invalidate(key).await;

        self.key_stats.remove(key);

        l1_result.and(l2_result)
    }

    /// Union of the tag's keys across both tiers, sorted and deduplicated.
    async fn get_keys_by_tag(&self, tag: &str) -> Result<Vec<String>, CacheError> {
        let mut keys = self.l1.get_keys_by_tag(tag).await?;
        let l2_keys = self.l2.get_keys_by_tag(tag).await?;

        keys.extend(l2_keys);
        keys.sort();
        keys.dedup();

        Ok(keys)
    }

    /// Invalidates a tag in both tiers and sums the per-tier counts.
    ///
    /// NOTE(review): a key present in both tiers is counted twice here,
    /// unlike `get_keys_by_tag` which deduplicates — confirm intended.
    async fn invalidate_by_tag(&self, tag: &str) -> Result<usize, CacheError> {
        let l1_count = self.l1.invalidate_by_tag(tag).await?;
        let l2_count = self.l2.invalidate_by_tag(tag).await?;

        Ok(l1_count + l2_count)
    }

    /// Union of all tags across both tiers, sorted and deduplicated.
    async fn list_tags(&self) -> Result<Vec<String>, CacheError> {
        let mut tags = self.l1.list_tags().await?;
        let l2_tags = self.l2.list_tags().await?;

        tags.extend(l2_tags);
        tags.sort();
        tags.dedup();

        Ok(tags)
    }
}
362
/// Builder for [`MultiTierBackend`]; both tiers must be supplied before
/// calling `build`.
pub struct MultiTierBuilder<L1, L2> {
    // Required: fast tier.
    l1: Option<L1>,
    // Required: large tier.
    l2: Option<L2>,
    config: MultiTierConfig,
}
369
370impl<L1, L2> MultiTierBuilder<L1, L2> {
371 pub fn new() -> Self {
373 Self {
374 l1: None,
375 l2: None,
376 config: MultiTierConfig::default(),
377 }
378 }
379
380 pub fn l1(mut self, backend: L1) -> Self {
382 self.l1 = Some(backend);
383 self
384 }
385
386 pub fn l2(mut self, backend: L2) -> Self {
388 self.l2 = Some(backend);
389 self
390 }
391
392 pub fn promotion_strategy(mut self, strategy: PromotionStrategy) -> Self {
394 self.config.promotion_strategy = strategy;
395 self
396 }
397
398 pub fn promotion_threshold(mut self, threshold: u64) -> Self {
400 self.config.promotion_strategy = PromotionStrategy::HitCount { threshold };
401 self
402 }
403
404 pub fn write_through(mut self, enabled: bool) -> Self {
406 self.config.write_through = enabled;
407 self
408 }
409
410 pub fn max_l1_entry_size(mut self, size: Option<usize>) -> Self {
413 self.config.max_l1_entry_size = size;
414 self
415 }
416
417 pub fn build(self) -> MultiTierBackend<L1, L2> {
419 MultiTierBackend {
420 l1: self.l1.expect("L1 backend is required"),
421 l2: self.l2.expect("L2 backend is required"),
422 config: self.config,
423 key_stats: Arc::new(DashMap::new()),
424 tier_stats: Arc::new(TierStats::default()),
425 }
426 }
427}
428
429impl<L1, L2> Default for MultiTierBuilder<L1, L2> {
430 fn default() -> Self {
431 Self::new()
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::memory::InMemoryBackend;
    use bytes::Bytes;
    use http::{StatusCode, Version};

    /// Minimal 200-OK cache entry with a tiny static body.
    fn test_entry() -> CacheEntry {
        CacheEntry::new(
            StatusCode::OK,
            Version::HTTP_11,
            Vec::new(),
            Bytes::from_static(b"test"),
        )
    }

    /// Polls L1 for `key`, returning true once it appears.
    ///
    /// BUGFIX: promotion runs on a spawned task, so the fixed 50ms sleep
    /// previously used here was a race that could flake under load. Bounded
    /// polling (up to ~500ms) keeps these tests deterministic while usually
    /// finishing in a few milliseconds.
    async fn promoted_to_l1(l1: &InMemoryBackend, key: &str) -> bool {
        for _ in 0..100 {
            if l1.get(key).await.unwrap().is_some() {
                return true;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
        false
    }

    #[tokio::test]
    async fn multi_tier_l1_hit() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1.clone(), l2);

        l1.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();

        let result = backend.get("key").await.unwrap();
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn multi_tier_l2_hit_and_promote() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);

        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .promotion_threshold(3)
            .build();

        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();

        // Three L2 hits reach the threshold and trigger async promotion.
        for _ in 0..3 {
            let result = backend.get("key").await.unwrap();
            assert!(result.is_some());
        }

        assert!(promoted_to_l1(&l1, "key").await);
    }

    #[tokio::test]
    async fn multi_tier_set_writes_to_both_tiers() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .write_through(true)
            .build();

        backend
            .set(
                "key".to_string(),
                test_entry(),
                Duration::from_secs(60),
                Duration::ZERO,
            )
            .await
            .unwrap();

        assert!(l1.get("key").await.unwrap().is_some());
        assert!(l2.get("key").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn multi_tier_invalidate_both_tiers() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1.clone(), l2.clone());

        l1.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();
        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();

        backend.invalidate("key").await.unwrap();

        assert!(l1.get("key").await.unwrap().is_none());
        assert!(l2.get("key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn multi_tier_miss() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1, l2);

        let result = backend.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn promotion_strategy_hit_count() {
        let strategy = PromotionStrategy::HitCount { threshold: 5 };
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);

        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .promotion_strategy(strategy)
            .build();

        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();

        for _ in 0..5 {
            backend.get("key").await.unwrap();
        }

        assert!(promoted_to_l1(&l1, "key").await);
    }
}