tower_http_cache/backend/
memory.rs1use async_trait::async_trait;
2use moka::future::Cache;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5
6use super::{CacheBackend, CacheEntry, CacheRead};
7use crate::error::CacheError;
8use crate::tags::TagIndex;
9
10#[derive(Clone)]
14pub struct InMemoryBackend {
15 cache: Cache<String, StoredEntry>,
16 tag_index: Arc<TagIndex>,
17}
18
19#[derive(Clone)]
20struct StoredEntry {
21 entry: CacheEntry,
22 expires_at: SystemTime,
23 stale_until: SystemTime,
24}
25
26impl InMemoryBackend {
27 pub fn new(max_capacity: u64) -> Self {
31 let cache = Cache::builder().max_capacity(max_capacity).build();
32 Self {
33 cache,
34 tag_index: Arc::new(TagIndex::new()),
35 }
36 }
37}
38
39#[async_trait]
40impl CacheBackend for InMemoryBackend {
41 async fn get(&self, key: &str) -> Result<Option<CacheRead>, CacheError> {
42 if let Some(stored) = self.cache.get(key).await {
43 let now = SystemTime::now();
44 if now > stored.stale_until {
45 self.cache.invalidate(key).await;
46 return Ok(None);
47 }
48
49 Ok(Some(CacheRead {
50 entry: stored.entry.clone(),
51 expires_at: Some(stored.expires_at),
52 stale_until: Some(stored.stale_until),
53 }))
54 } else {
55 Ok(None)
56 }
57 }
58
59 async fn set(
60 &self,
61 key: String,
62 entry: CacheEntry,
63 ttl: Duration,
64 stale_for: Duration,
65 ) -> Result<(), CacheError> {
66 if ttl.is_zero() {
67 return Ok(());
68 }
69
70 let now = SystemTime::now();
71 let expires_at = now + ttl;
72 let stale_until = expires_at + stale_for;
73
74 if let Some(ref tags) = entry.tags {
76 if !tags.is_empty() {
77 self.tag_index.index(key.clone(), tags.clone());
78 }
79 }
80
81 let stored = StoredEntry {
82 entry,
83 expires_at,
84 stale_until,
85 };
86 self.cache.insert(key, stored).await;
87 Ok(())
88 }
89
90 async fn invalidate(&self, key: &str) -> Result<(), CacheError> {
91 self.cache.invalidate(key).await;
92 self.tag_index.remove(key);
93 Ok(())
94 }
95
96 async fn get_keys_by_tag(&self, tag: &str) -> Result<Vec<String>, CacheError> {
97 Ok(self.tag_index.get_keys_by_tag(tag))
98 }
99
100 async fn list_tags(&self) -> Result<Vec<String>, CacheError> {
101 Ok(self.tag_index.list_tags())
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::CacheEntry;
    use bytes::Bytes;
    use http::{StatusCode, Version};
    use std::time::SystemTime;
    use tokio::time::{sleep, Duration};

    /// Builds a minimal `200 OK` / HTTP/1.1 entry carrying `body`.
    fn entry_with_body(body: &'static [u8]) -> CacheEntry {
        CacheEntry::new(
            StatusCode::OK,
            Version::HTTP_11,
            Vec::new(),
            Bytes::from_static(body),
        )
    }

    #[tokio::test]
    async fn set_and_get_returns_cached_entry() {
        let backend = InMemoryBackend::new(16);
        let entry = entry_with_body(b"alpha");

        backend
            .set(
                "key".into(),
                entry.clone(),
                Duration::from_secs(1),
                Duration::from_secs(1),
            )
            .await
            .expect("set succeeds");

        let cached = backend
            .get("key")
            .await
            .expect("get succeeds")
            .expect("entry present");

        assert_eq!(cached.entry.body, entry.body);
        let expires_at = cached.expires_at.expect("expiry recorded");
        let stale_until = cached.stale_until.expect("stale deadline recorded");
        assert!(expires_at < stale_until, "stale window extends past expiry");
    }

    #[tokio::test]
    async fn zero_ttl_is_not_stored() {
        let backend = InMemoryBackend::new(16);

        backend
            .set(
                "key".into(),
                entry_with_body(b"uncacheable"),
                Duration::ZERO,
                Duration::from_secs(1),
            )
            .await
            .expect("set succeeds");

        let read = backend.get("key").await.expect("get succeeds");
        assert!(read.is_none(), "zero ttl must not store anything");
    }

    #[tokio::test]
    async fn invalidate_removes_entry() {
        let backend = InMemoryBackend::new(16);

        backend
            .set(
                "key".into(),
                entry_with_body(b"gone"),
                Duration::from_secs(1),
                Duration::from_secs(1),
            )
            .await
            .expect("set succeeds");
        backend.invalidate("key").await.expect("invalidate succeeds");

        let read = backend.get("key").await.expect("get succeeds");
        assert!(read.is_none(), "entry removed by invalidate");
    }

    #[tokio::test]
    async fn entry_invalidated_after_stale_window() {
        // The backend reads the wall clock (`SystemTime`), which tokio's paused test
        // clock does not control, so this test really sleeps. The stale window is
        // wide so scheduler jitter on a busy host cannot push the first read past it.
        let backend = InMemoryBackend::new(16);

        backend
            .set(
                "key".into(),
                entry_with_body(b"stale"),
                Duration::from_millis(20),
                Duration::from_millis(200),
            )
            .await
            .expect("set succeeds");

        sleep(Duration::from_millis(50)).await;
        let cached = backend
            .get("key")
            .await
            .expect("get succeeds")
            .expect("entry available during stale window");
        let expires_at = cached.expires_at.expect("expiry recorded");
        assert!(expires_at <= SystemTime::now(), "entry is past its ttl");

        sleep(Duration::from_millis(200)).await;
        let read = backend.get("key").await.expect("get succeeds");
        assert!(read.is_none(), "entry removed after stale window");
    }
}