// tork_core/cache/handle.rs
use std::fmt;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;

use serde::de::DeserializeOwned;
use serde::Serialize;

use crate::error::{Error, Result};
use crate::extract::{FromRequest, RequestContext};

use super::memory::MemoryStore;
use super::store::CacheStore;

16#[derive(Clone)]
34pub struct Cache {
35 store: Arc<dyn CacheStore>,
36 default_ttl: Option<Duration>,
37}
38
39impl Cache {
40 pub fn new(store: impl CacheStore) -> Self {
42 Self {
43 store: Arc::new(store),
44 default_ttl: None,
45 }
46 }
47
48 pub fn in_memory() -> Self {
50 Self::new(MemoryStore::new())
51 }
52
53 #[cfg(feature = "redis")]
59 pub async fn redis(url: &str) -> Result<Self> {
60 Ok(Self::new(super::RedisStore::connect(url).await?))
61 }
62
63 #[cfg(feature = "redis")]
69 pub fn from_redis(redis: &crate::Redis) -> Self {
70 Self::new(super::RedisStore::from_redis(
71 redis,
72 super::RedisStore::default_prefix(),
73 ))
74 }
75
76 pub fn default_ttl(mut self, ttl: Duration) -> Self {
81 self.default_ttl = normalize_ttl(Some(ttl));
82 self
83 }
84
85 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
87 match self.store.get(key).await? {
88 Some(bytes) => {
89 let value = serde_json::from_slice(&bytes).map_err(|error| {
90 Error::internal(format!("cache value could not be deserialized: {error}"))
91 })?;
92 Ok(Some(value))
93 }
94 None => Ok(None),
95 }
96 }
97
98 pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
100 self.write(key, value, self.default_ttl).await
101 }
102
103 pub async fn set_ttl<T: Serialize>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
107 self.write(key, value, normalize_ttl(Some(ttl))).await
108 }
109
110 pub async fn get_or_set<T, F, Fut>(
116 &self,
117 key: &str,
118 ttl: Option<Duration>,
119 init: F,
120 ) -> Result<T>
121 where
122 T: Serialize + DeserializeOwned,
123 F: FnOnce() -> Fut,
124 Fut: Future<Output = Result<T>>,
125 {
126 if let Some(found) = self.get::<T>(key).await? {
127 return Ok(found);
128 }
129 let value = init().await?;
130 let ttl = match ttl {
131 Some(ttl) => normalize_ttl(Some(ttl)),
132 None => self.default_ttl,
133 };
134 self.write(key, &value, ttl).await?;
135 Ok(value)
136 }
137
138 pub async fn delete(&self, key: &str) -> Result<()> {
140 self.store.delete(key).await
141 }
142
143 pub async fn clear(&self) -> Result<()> {
145 self.store.clear().await
146 }
147
148 async fn write<T: Serialize>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()> {
150 let bytes = serde_json::to_vec(value).map_err(|error| {
151 Error::internal(format!("cache value could not be serialized: {error}"))
152 })?;
153 self.store.set(key.to_owned(), bytes, ttl).await
154 }
155}
156
/// Canonicalizes a TTL so that "zero" and "absent" mean the same thing.
///
/// `Some(Duration::ZERO)` becomes `None`. Every other value, including
/// `None` itself, passes through untouched.
fn normalize_ttl(ttl: Option<Duration>) -> Option<Duration> {
    ttl.filter(|duration| !duration.is_zero())
}

165impl FromRequest for Cache {
166 fn from_request(ctx: &RequestContext) -> impl Future<Output = Result<Self>> + Send {
167 let resolved = ctx
168 .state()
169 .get::<Cache>()
170 .map(|cache| (*cache).clone())
171 .ok_or_else(|| {
172 Error::internal("cache is not configured; call `App::cache(...)` to enable it")
173 });
174 async move { resolved }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde::Deserialize;

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct User {
        id: i64,
        name: String,
    }

    #[tokio::test]
    async fn round_trips_a_typed_value() {
        let cache = Cache::in_memory();
        let user = User {
            id: 1,
            name: "alice".into(),
        };
        cache.set("user:1", &user).await.unwrap();

        let stored: Option<User> = cache.get("user:1").await.unwrap();
        assert_eq!(stored, Some(user));
    }

    #[tokio::test]
    async fn a_missing_key_is_none() {
        let cache = Cache::in_memory();
        let stored: Option<String> = cache.get("absent").await.unwrap();
        assert_eq!(stored, None);
    }

    #[tokio::test]
    async fn an_entry_expires_after_its_ttl() {
        let cache = Cache::in_memory();
        cache
            .set_ttl("k", &"v", Duration::from_millis(50))
            .await
            .unwrap();

        assert_eq!(
            cache.get::<String>("k").await.unwrap().as_deref(),
            Some("v")
        );
        tokio::time::sleep(Duration::from_millis(120)).await;
        assert_eq!(cache.get::<String>("k").await.unwrap(), None);
    }

    #[tokio::test]
    async fn a_zero_ttl_never_expires() {
        let cache = Cache::in_memory();
        cache.set_ttl("k", &"v", Duration::ZERO).await.unwrap();

        tokio::time::sleep(Duration::from_millis(80)).await;
        assert_eq!(
            cache.get::<String>("k").await.unwrap().as_deref(),
            Some("v")
        );
    }

    #[tokio::test]
    async fn default_ttl_applies_to_plain_set() {
        let cache = Cache::in_memory().default_ttl(Duration::from_millis(50));
        cache.set("k", &"v").await.unwrap();

        tokio::time::sleep(Duration::from_millis(120)).await;
        assert_eq!(cache.get::<String>("k").await.unwrap(), None);
    }

    #[tokio::test]
    async fn get_or_set_computes_once_then_hits_the_cache() {
        let cache = Cache::in_memory();
        let calls = AtomicUsize::new(0);

        let compute = || async {
            calls.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(User {
                id: 7,
                name: "bob".into(),
            })
        };

        let first = cache.get_or_set("user:7", None, compute).await.unwrap();
        let second = cache.get_or_set("user:7", None, compute).await.unwrap();

        assert_eq!(first, second);
        assert_eq!(calls.load(Ordering::SeqCst), 1, "init runs only on a miss");
    }

    #[tokio::test]
    async fn get_or_set_propagates_init_errors_without_caching() {
        let cache = Cache::in_memory();

        let outcome = cache
            .get_or_set::<i32, _, _>("k", None, || async { Err(Error::internal("boom")) })
            .await;

        assert!(outcome.is_err(), "the init error is returned to the caller");
        assert_eq!(
            cache.get::<i32>("k").await.unwrap(),
            None,
            "a failed init must not leave anything in the cache"
        );
    }

    #[tokio::test]
    async fn a_value_of_the_wrong_type_is_an_error() {
        let cache = Cache::in_memory();
        cache.set("k", &"not a number").await.unwrap();

        assert!(cache.get::<i32>("k").await.is_err());
    }

    #[tokio::test]
    async fn clones_share_the_same_store() {
        let cache = Cache::in_memory();
        let clone = cache.clone();

        cache.set("k", &42).await.unwrap();
        assert_eq!(clone.get::<i32>("k").await.unwrap(), Some(42));

        clone.delete("k").await.unwrap();
        assert_eq!(cache.get::<i32>("k").await.unwrap(), None);
    }

    #[tokio::test]
    async fn delete_and_clear_remove_entries() {
        let cache = Cache::in_memory();
        cache.set("a", &1).await.unwrap();
        cache.set("b", &2).await.unwrap();

        cache.delete("a").await.unwrap();
        assert_eq!(cache.get::<i32>("a").await.unwrap(), None);
        assert_eq!(cache.get::<i32>("b").await.unwrap(), Some(2));

        cache.clear().await.unwrap();
        assert_eq!(cache.get::<i32>("b").await.unwrap(), None);
    }
}