tower_http_cache/backend/multi_tier.rs

//! Multi-tier caching backend with automatic promotion.
//!
//! This module implements a two-tier caching architecture where frequently
//! accessed entries are automatically promoted from a slower L2 cache to
//! a faster L1 cache based on configurable promotion strategies.
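//!
//! A minimal usage sketch, assuming the crate's `InMemoryBackend` for both
//! tiers (any two `CacheBackend` implementations work; the capacities and
//! threshold below are illustrative):
//!
//! ```rust,ignore
//! use tower_http_cache::backend::memory::InMemoryBackend;
//! use tower_http_cache::backend::multi_tier::MultiTierBackend;
//!
//! let backend = MultiTierBackend::builder()
//!     .l1(InMemoryBackend::new(100))      // small, fast tier
//!     .l2(InMemoryBackend::new(10_000))   // large, slower tier
//!     .promotion_threshold(3)             // promote after 3 L2 hits
//!     .build();
//! ```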

use async_trait::async_trait;
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

use super::{CacheBackend, CacheEntry, CacheRead};
use crate::error::CacheError;

/// Strategy for promoting entries from L2 to L1.
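///
/// A sketch of choosing a strategy (values are illustrative, not
/// recommendations):
///
/// ```rust,ignore
/// // Promote an entry to L1 once it has been read from L2 five times.
/// let by_count = PromotionStrategy::HitCount { threshold: 5 };
///
/// // Promote by hit rate (currently approximated by a hit count; see below).
/// let by_rate = PromotionStrategy::HitRate { threshold_per_minute: 30 };
/// ```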
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromotionStrategy {
    /// Promote after a fixed number of hits.
    HitCount { threshold: u64 },

    /// Promote based on hit rate over a time window.
    ///
    /// Note: currently approximated by a fixed hit count rather than a true
    /// per-minute rate; see the promotion check in `MultiTierBackend`.
    HitRate { threshold_per_minute: u64 },
}

impl Default for PromotionStrategy {
    fn default() -> Self {
        Self::HitCount { threshold: 3 }
    }
}

/// Hit, miss, and promotion statistics across both cache tiers.
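///
/// Counters are atomic so they can be updated through the shared
/// `Arc<TierStats>` held by cloned backends. A sketch of reading them
/// (`backend` is assumed to be a `MultiTierBackend`):
///
/// ```rust,ignore
/// use std::sync::atomic::Ordering;
///
/// let stats = backend.stats();
/// let l1 = stats.l1_hits.load(Ordering::Relaxed);
/// let l2 = stats.l2_hits.load(Ordering::Relaxed);
/// let misses = stats.misses.load(Ordering::Relaxed);
/// println!("hits: {l1} (L1) + {l2} (L2), misses: {misses}");
/// ```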
#[derive(Debug, Default)]
pub struct TierStats {
    pub l1_hits: AtomicU64,
    pub l2_hits: AtomicU64,
    pub misses: AtomicU64,
    pub promotions: AtomicU64,
}

/// Configuration for multi-tier caching.
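///
/// A sketch of overriding the defaults directly (the builder offers the same
/// knobs; values here are illustrative):
///
/// ```rust,ignore
/// let config = MultiTierConfig {
///     write_through: false,                // populate L1 only via promotion
///     max_l1_entry_size: Some(64 * 1024),  // keep bodies over 64KB out of L1
///     ..MultiTierConfig::default()
/// };
/// ```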
#[derive(Debug, Clone)]
pub struct MultiTierConfig {
    /// Strategy for promoting entries from L2 to L1.
    pub promotion_strategy: PromotionStrategy,

    /// Whether to write to both tiers on set (true) or L2 only (false).
    pub write_through: bool,

    /// Maximum entry size stored in L1 (default: 256KB).
    /// Larger entries are stored only in L2 to prevent L1 pollution.
    pub max_l1_entry_size: Option<usize>,
}

impl Default for MultiTierConfig {
    fn default() -> Self {
        Self {
            promotion_strategy: PromotionStrategy::default(),
            write_through: true,
            max_l1_entry_size: Some(256 * 1024), // 256KB
        }
    }
}

/// Per-key statistics for promotion tracking.
struct KeyStats {
    l2_hits: AtomicU64,
}

impl KeyStats {
    fn new() -> Self {
        Self {
            l2_hits: AtomicU64::new(0),
        }
    }

    /// Increments the hit counter and returns the new count.
    ///
    /// Relaxed ordering suffices: the counter is advisory and does not
    /// synchronize access to any other data.
    fn record_hit(&self) -> u64 {
        self.l2_hits.fetch_add(1, Ordering::Relaxed) + 1
    }

    /// Resets the counter after a successful promotion.
    fn reset(&self) {
        self.l2_hits.store(0, Ordering::Relaxed);
    }

    fn hits(&self) -> u64 {
        self.l2_hits.load(Ordering::Relaxed)
    }
}

/// Multi-tier cache backend combining a fast L1 and a persistent L2.
///
/// The multi-tier backend automatically promotes frequently accessed entries
/// from L2 to L1 based on the configured promotion strategy.
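///
/// A minimal sketch of the read path, assuming the crate's `InMemoryBackend`
/// for both tiers:
///
/// ```rust,ignore
/// let backend = MultiTierBackend::new(
///     InMemoryBackend::new(100),    // L1: small and fast
///     InMemoryBackend::new(10_000), // L2: large and slower
/// );
///
/// // A read checks L1 first, then falls back to L2 and records the hit
/// // toward the promotion threshold.
/// if let Some(read) = backend.get("key").await? {
///     // serve `read.entry` ...
/// }
/// ```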
#[derive(Clone)]
pub struct MultiTierBackend<L1, L2> {
    l1: L1,
    l2: L2,
    config: MultiTierConfig,
    /// Per-key hit counters used to decide promotion.
    key_stats: Arc<DashMap<String, Arc<KeyStats>>>,
    /// Aggregate counters shared across clones of this backend.
    tier_stats: Arc<TierStats>,
}

impl<L1, L2> MultiTierBackend<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// Creates a new multi-tier backend with default configuration.
    pub fn new(l1: L1, l2: L2) -> Self {
        Self {
            l1,
            l2,
            config: MultiTierConfig::default(),
            key_stats: Arc::new(DashMap::new()),
            tier_stats: Arc::new(TierStats::default()),
        }
    }

    /// Creates a builder for configuring the multi-tier backend.
    pub fn builder() -> MultiTierBuilder<L1, L2> {
        MultiTierBuilder::new()
    }

    /// Returns a reference to the L1 backend.
    pub fn l1(&self) -> &L1 {
        &self.l1
    }

    /// Returns a reference to the L2 backend.
    pub fn l2(&self) -> &L2 {
        &self.l2
    }

    /// Returns a reference to the current tier statistics.
    pub fn stats(&self) -> &TierStats {
        &self.tier_stats
    }

    /// Checks whether an entry should be promoted from L2 to L1.
    fn should_promote(&self, key: &str) -> bool {
        let stats = self
            .key_stats
            .entry(key.to_string())
            .or_insert_with(|| Arc::new(KeyStats::new()));

        match self.config.promotion_strategy {
            PromotionStrategy::HitCount { threshold } => stats.hits() >= threshold,
            PromotionStrategy::HitRate {
                threshold_per_minute: _,
            } => {
                // Approximated by a fixed hit count for now; a full
                // implementation would track timestamped hits within the
                // one-minute window.
                stats.hits() >= 3
            }
        }
    }

    /// Records a hit on a key and returns the hit count.
    fn record_hit(&self, key: &str) -> u64 {
        self.key_stats
            .entry(key.to_string())
            .or_insert_with(|| Arc::new(KeyStats::new()))
            .record_hit()
    }

    /// Promotes an entry from L2 to L1 and resets its promotion counter.
    ///
    /// Currently unused: `get` spawns its own best-effort promotion task.
    #[allow(dead_code)]
    async fn promote(
        &self,
        key: &str,
        entry: CacheEntry,
        ttl: Duration,
        stale_for: Duration,
    ) -> Result<(), CacheError> {
        // Store in L1
        self.l1.set(key.to_string(), entry, ttl, stale_for).await?;

        // Reset promotion counter
        if let Some(stats) = self.key_stats.get(key) {
            stats.reset();
        }

        Ok(())
    }
}

#[async_trait]
impl<L1, L2> CacheBackend for MultiTierBackend<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    async fn get(&self, key: &str) -> Result<Option<CacheRead>, CacheError> {
        // Try L1 first
        if let Some(read) = self.l1.get(key).await? {
            self.tier_stats.l1_hits.fetch_add(1, Ordering::Relaxed);
            #[cfg(feature = "metrics")]
            metrics::counter!("tower_http_cache.tier.l1_hit").increment(1);
            return Ok(Some(read));
        }

        // Try L2
        if let Some(read) = self.l2.get(key).await? {
            self.tier_stats.l2_hits.fetch_add(1, Ordering::Relaxed);
            #[cfg(feature = "metrics")]
            metrics::counter!("tower_http_cache.tier.l2_hit").increment(1);

            // Record the hit and check whether the key has crossed the
            // promotion threshold
            self.record_hit(key);

            if self.should_promote(key) {
                let entry_size = read.entry.body.len();

                // Check if the entry is small enough for L1
                let should_promote_l1 = if let Some(max_size) = self.config.max_l1_entry_size {
                    entry_size <= max_size
                } else {
                    true
                };

                if should_promote_l1 {
                    self.tier_stats.promotions.fetch_add(1, Ordering::Relaxed);
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.promoted").increment(1);

                    // Promote with the remaining TTL; fall back to 60s if the
                    // entry has no expiry or is already at/past it
                    let ttl = if let Some(expires_at) = read.expires_at {
                        expires_at
                            .duration_since(std::time::SystemTime::now())
                            .unwrap_or(Duration::from_secs(60))
                    } else {
                        Duration::from_secs(60)
                    };

                    let stale_for = if let (Some(stale_until), Some(expires_at)) =
                        (read.stale_until, read.expires_at)
                    {
                        stale_until.duration_since(expires_at).unwrap_or_default()
                    } else {
                        Duration::ZERO
                    };

                    // Promote asynchronously (best effort): the read is served
                    // immediately and a failed L1 write is simply dropped
                    let entry = read.entry.clone();
                    let key = key.to_string();
                    let l1 = self.l1.clone();
                    let key_stats = self.key_stats.clone();

                    tokio::spawn(async move {
                        let _ = l1.set(key.clone(), entry, ttl, stale_for).await;
                        if let Some(stats) = key_stats.get(&key) {
                            stats.reset();
                        }
                    });
                } else {
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.promotion_skipped_large").increment(1);

                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        key = %key,
                        size = entry_size,
                        max_l1_size = ?self.config.max_l1_entry_size,
                        "skipping promotion for large entry"
                    );
                }
            }

            return Ok(Some(read));
        }

        self.tier_stats.misses.fetch_add(1, Ordering::Relaxed);
        Ok(None)
    }

    async fn set(
        &self,
        key: String,
        entry: CacheEntry,
        ttl: Duration,
        stale_for: Duration,
    ) -> Result<(), CacheError> {
        let entry_size = entry.body.len();

        // Always write to L2
        self.l2
            .set(key.clone(), entry.clone(), ttl, stale_for)
            .await?;

        // Optionally write to L1 if write-through is enabled and the entry
        // fits under the L1 size limit
        if self.config.write_through {
            let should_write_l1 = if let Some(max_size) = self.config.max_l1_entry_size {
                if entry_size <= max_size {
                    true
                } else {
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.l1_skipped_large").increment(1);

                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        key = %key,
                        size = entry_size,
                        max_l1_size = max_size,
                        "skipping L1 write for large entry"
                    );

                    false
                }
            } else {
                true
            };

            if should_write_l1 {
                // Best effort: L2 already holds the entry, so an L1 write
                // failure is not surfaced to the caller
                let _ = self.l1.set(key.clone(), entry, ttl, stale_for).await;
            }
        }

        Ok(())
    }

    async fn invalidate(&self, key: &str) -> Result<(), CacheError> {
        // Invalidate both tiers
        let l1_result = self.l1.invalidate(key).await;
        let l2_result = self.l2.invalidate(key).await;

        // Remove promotion stats for the key
        self.key_stats.remove(key);

        // Surface the L1 error if any, otherwise the L2 result
        l1_result.and(l2_result)
    }

    async fn get_keys_by_tag(&self, tag: &str) -> Result<Vec<String>, CacheError> {
        // Query both tiers, then merge and deduplicate the results
        let mut keys = self.l1.get_keys_by_tag(tag).await?;
        let l2_keys = self.l2.get_keys_by_tag(tag).await?;

        keys.extend(l2_keys);
        keys.sort();
        keys.dedup();

        Ok(keys)
    }

    async fn invalidate_by_tag(&self, tag: &str) -> Result<usize, CacheError> {
        // Invalidate in both tiers. Note: a key present in both tiers is
        // counted once per tier.
        let l1_count = self.l1.invalidate_by_tag(tag).await?;
        let l2_count = self.l2.invalidate_by_tag(tag).await?;

        Ok(l1_count + l2_count)
    }

    async fn list_tags(&self) -> Result<Vec<String>, CacheError> {
        // Merge and deduplicate tags from both tiers
        let mut tags = self.l1.list_tags().await?;
        let l2_tags = self.l2.list_tags().await?;

        tags.extend(l2_tags);
        tags.sort();
        tags.dedup();

        Ok(tags)
    }
}

/// Builder for configuring a multi-tier backend.
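///
/// A sketch of a fully configured backend; `fast_backend` and
/// `durable_backend` are hypothetical `CacheBackend` implementations:
///
/// ```rust,ignore
/// let backend = MultiTierBackend::builder()
///     .l1(fast_backend)
///     .l2(durable_backend)
///     .promotion_strategy(PromotionStrategy::HitCount { threshold: 5 })
///     .write_through(true)
///     .max_l1_entry_size(Some(128 * 1024))
///     .build();
/// ```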
pub struct MultiTierBuilder<L1, L2> {
    l1: Option<L1>,
    l2: Option<L2>,
    config: MultiTierConfig,
}

impl<L1, L2> MultiTierBuilder<L1, L2> {
    /// Creates a new builder.
    pub fn new() -> Self {
        Self {
            l1: None,
            l2: None,
            config: MultiTierConfig::default(),
        }
    }

    /// Sets the L1 (fast) cache backend.
    pub fn l1(mut self, backend: L1) -> Self {
        self.l1 = Some(backend);
        self
    }

    /// Sets the L2 (persistent) cache backend.
    pub fn l2(mut self, backend: L2) -> Self {
        self.l2 = Some(backend);
        self
    }

    /// Sets the promotion strategy.
    pub fn promotion_strategy(mut self, strategy: PromotionStrategy) -> Self {
        self.config.promotion_strategy = strategy;
        self
    }

    /// Sets the promotion threshold for hit-count based promotion.
    pub fn promotion_threshold(mut self, threshold: u64) -> Self {
        self.config.promotion_strategy = PromotionStrategy::HitCount { threshold };
        self
    }

    /// Enables or disables write-through to L1.
    pub fn write_through(mut self, enabled: bool) -> Self {
        self.config.write_through = enabled;
        self
    }

    /// Sets the maximum entry size for the L1 cache.
    /// Entries larger than this will only be stored in L2.
    pub fn max_l1_entry_size(mut self, size: Option<usize>) -> Self {
        self.config.max_l1_entry_size = size;
        self
    }

    /// Builds the multi-tier backend.
    ///
    /// # Panics
    ///
    /// Panics if either the L1 or the L2 backend has not been set.
    pub fn build(self) -> MultiTierBackend<L1, L2> {
        MultiTierBackend {
            l1: self.l1.expect("L1 backend is required"),
            l2: self.l2.expect("L2 backend is required"),
            config: self.config,
            key_stats: Arc::new(DashMap::new()),
            tier_stats: Arc::new(TierStats::default()),
        }
    }
}

impl<L1, L2> Default for MultiTierBuilder<L1, L2> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::memory::InMemoryBackend;
    use bytes::Bytes;
    use http::{StatusCode, Version};

    fn test_entry() -> CacheEntry {
        CacheEntry::new(
            StatusCode::OK,
            Version::HTTP_11,
            Vec::new(),
            Bytes::from_static(b"test"),
        )
    }

    #[tokio::test]
    async fn multi_tier_l1_hit() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1.clone(), l2);

        // Store in L1
        l1.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();

        // Should hit L1
        let result = backend.get("key").await.unwrap();
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn multi_tier_l2_hit_and_promote() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);

        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .promotion_threshold(3)
            .build();

        // Store in L2 only
        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();

        // First few hits should be from L2
        for _ in 0..3 {
            let result = backend.get("key").await.unwrap();
            assert!(result.is_some());
        }

        // Give the spawned promotion task time to complete
        tokio::time::sleep(Duration::from_millis(50)).await;

        // After the promotion threshold, the entry should be in L1
        let l1_result = l1.get("key").await.unwrap();
        assert!(l1_result.is_some());
    }

    #[tokio::test]
    async fn multi_tier_set_writes_to_both_tiers() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .write_through(true)
            .build();

        backend
            .set(
                "key".to_string(),
                test_entry(),
                Duration::from_secs(60),
                Duration::ZERO,
            )
            .await
            .unwrap();

        // Should be in both L1 and L2
        assert!(l1.get("key").await.unwrap().is_some());
        assert!(l2.get("key").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn multi_tier_invalidate_both_tiers() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1.clone(), l2.clone());

        // Store in both
        l1.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();
        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();

        // Invalidate through the multi-tier backend
        backend.invalidate("key").await.unwrap();

        // Should be removed from both
        assert!(l1.get("key").await.unwrap().is_none());
        assert!(l2.get("key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn multi_tier_miss() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1, l2);

        let result = backend.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn promotion_strategy_hit_count() {
        let strategy = PromotionStrategy::HitCount { threshold: 5 };
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);

        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .promotion_strategy(strategy)
            .build();

        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();

        // Hit 5 times to trigger promotion
        for _ in 0..5 {
            backend.get("key").await.unwrap();
        }

        tokio::time::sleep(Duration::from_millis(50)).await;

        // Should be promoted to L1
        assert!(l1.get("key").await.unwrap().is_some());
    }
}