tower_http_cache/backend/memory.rs

1use async_trait::async_trait;
2use moka::future::Cache;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5
6use super::{CacheBackend, CacheEntry, CacheRead};
7use crate::error::CacheError;
8use crate::tags::TagIndex;
9
10/// An in-memory [`CacheBackend`] implementation backed by [`moka`].
11///
12/// The backend is cheap to clone and shares a single underlying cache.
13#[derive(Clone)]
14pub struct InMemoryBackend {
15    cache: Cache<String, StoredEntry>,
16    tag_index: Arc<TagIndex>,
17}
18
19#[derive(Clone)]
20struct StoredEntry {
21    entry: CacheEntry,
22    expires_at: SystemTime,
23    stale_until: SystemTime,
24}
25
26impl InMemoryBackend {
27    /// Creates a new in-memory cache with the provided `max_capacity`.
28    ///
29    /// The capacity is expressed in number of cached entries, not bytes.
30    pub fn new(max_capacity: u64) -> Self {
31        let cache = Cache::builder().max_capacity(max_capacity).build();
32        Self {
33            cache,
34            tag_index: Arc::new(TagIndex::new()),
35        }
36    }
37}
38
39#[async_trait]
40impl CacheBackend for InMemoryBackend {
41    async fn get(&self, key: &str) -> Result<Option<CacheRead>, CacheError> {
42        if let Some(stored) = self.cache.get(key).await {
43            let now = SystemTime::now();
44            if now > stored.stale_until {
45                self.cache.invalidate(key).await;
46                return Ok(None);
47            }
48
49            Ok(Some(CacheRead {
50                entry: stored.entry.clone(),
51                expires_at: Some(stored.expires_at),
52                stale_until: Some(stored.stale_until),
53            }))
54        } else {
55            Ok(None)
56        }
57    }
58
59    async fn set(
60        &self,
61        key: String,
62        entry: CacheEntry,
63        ttl: Duration,
64        stale_for: Duration,
65    ) -> Result<(), CacheError> {
66        if ttl.is_zero() {
67            return Ok(());
68        }
69
70        let now = SystemTime::now();
71        let expires_at = now + ttl;
72        let stale_until = expires_at + stale_for;
73
74        // Index tags if present
75        if let Some(ref tags) = entry.tags {
76            if !tags.is_empty() {
77                self.tag_index.index(key.clone(), tags.clone());
78            }
79        }
80
81        let stored = StoredEntry {
82            entry,
83            expires_at,
84            stale_until,
85        };
86        self.cache.insert(key, stored).await;
87        Ok(())
88    }
89
90    async fn invalidate(&self, key: &str) -> Result<(), CacheError> {
91        self.cache.invalidate(key).await;
92        self.tag_index.remove(key);
93        Ok(())
94    }
95
96    async fn get_keys_by_tag(&self, tag: &str) -> Result<Vec<String>, CacheError> {
97        Ok(self.tag_index.get_keys_by_tag(tag))
98    }
99
100    async fn list_tags(&self) -> Result<Vec<String>, CacheError> {
101        Ok(self.tag_index.list_tags())
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::CacheEntry;
    use bytes::Bytes;
    use http::{StatusCode, Version};
    use std::time::SystemTime;
    use tokio::time::{sleep, Duration};

    /// Builds a minimal `200 OK` / HTTP/1.1 entry with no headers and `body`.
    fn entry_with_body(body: &'static [u8]) -> CacheEntry {
        CacheEntry::new(
            StatusCode::OK,
            Version::HTTP_11,
            Vec::new(),
            Bytes::from_static(body),
        )
    }

    #[tokio::test]
    async fn set_and_get_returns_cached_entry() {
        let backend = InMemoryBackend::new(16);
        let entry = entry_with_body(b"alpha");

        backend
            .set(
                "key".into(),
                entry.clone(),
                Duration::from_secs(1),
                Duration::from_secs(1),
            )
            .await
            .expect("set succeeds");

        let read = backend.get("key").await.expect("get succeeds");
        let cached = read.expect("entry present");

        assert_eq!(cached.entry.body, entry.body);
        assert!(cached.expires_at.is_some());
        assert!(cached.stale_until.is_some());
    }

    #[tokio::test]
    async fn zero_ttl_is_not_stored() {
        let backend = InMemoryBackend::new(16);

        backend
            .set(
                "key".into(),
                entry_with_body(b"never"),
                Duration::ZERO,
                Duration::from_secs(1),
            )
            .await
            .expect("set succeeds");

        let read = backend.get("key").await.expect("get succeeds");
        assert!(read.is_none(), "zero ttl must not cache the entry");
    }

    #[tokio::test]
    async fn invalidate_removes_entry() {
        let backend = InMemoryBackend::new(16);

        backend
            .set(
                "key".into(),
                entry_with_body(b"gone"),
                Duration::from_secs(1),
                Duration::from_secs(1),
            )
            .await
            .expect("set succeeds");
        backend.invalidate("key").await.expect("invalidate succeeds");

        let read = backend.get("key").await.expect("get succeeds");
        assert!(read.is_none(), "entry removed by invalidate");
    }

    #[tokio::test]
    async fn entry_invalidated_after_stale_window() {
        let backend = InMemoryBackend::new(16);

        // Fresh for 20ms, then stale for another 200ms. The wide stale window
        // keeps the "still served while stale" check well clear of the
        // deadline even on a slow or loaded test runner.
        backend
            .set(
                "key".into(),
                entry_with_body(b"stale"),
                Duration::from_millis(20),
                Duration::from_millis(200),
            )
            .await
            .expect("set succeeds");

        sleep(Duration::from_millis(40)).await;
        let read = backend
            .get("key")
            .await
            .expect("get succeeds")
            .expect("entry available during stale window");
        let expires_at = read.expires_at.expect("expires_at is reported");
        assert!(
            SystemTime::now() > expires_at,
            "entry is past its ttl but still served"
        );

        sleep(Duration::from_millis(200)).await;
        let read = backend.get("key").await.expect("get succeeds");
        assert!(read.is_none(), "entry removed after stale window");
    }
}