Skip to main content

tork_core/cache/
handle.rs

1//! The typed cache handle.
2
3use std::future::Future;
4use std::sync::Arc;
5use std::time::Duration;
6
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9
10use crate::error::{Error, Result};
11use crate::extract::{FromRequest, RequestContext};
12
13use super::memory::MemoryStore;
14use super::store::CacheStore;
15
16/// A cheap-to-clone handle over a [`CacheStore`], with a typed, ergonomic API.
17///
18/// Values are serialized to JSON before being stored and deserialized on the way
19/// out, so any `serde` type can be cached. Cloning a `Cache` is cheap (it shares
20/// the underlying store), so it is held as an injected resource.
21///
22/// # Examples
23///
24/// ```no_run
25/// # use tork_core::Cache;
26/// # async fn run(cache: Cache) -> tork_core::Result<()> {
27/// cache.set("greeting", &"hello").await?;
28/// let greeting: Option<String> = cache.get("greeting").await?;
29/// assert_eq!(greeting.as_deref(), Some("hello"));
30/// # Ok(())
31/// # }
32/// ```
33#[derive(Clone)]
34pub struct Cache {
35    store: Arc<dyn CacheStore>,
36    default_ttl: Option<Duration>,
37}
38
39impl Cache {
40    /// Builds a cache over a custom [`CacheStore`].
41    pub fn new(store: impl CacheStore) -> Self {
42        Self {
43            store: Arc::new(store),
44            default_ttl: None,
45        }
46    }
47
48    /// Builds a cache over the default in-memory store ([`MemoryStore`]).
49    pub fn in_memory() -> Self {
50        Self::new(MemoryStore::new())
51    }
52
53    /// Builds a cache over a Redis store at `url` (for example
54    /// `redis://127.0.0.1:6379`), opening its own connection, for sharing the cache
55    /// across instances.
56    ///
57    /// Available with the `redis` feature.
58    #[cfg(feature = "redis")]
59    pub async fn redis(url: &str) -> Result<Self> {
60        Ok(Self::new(super::RedisStore::connect(url).await?))
61    }
62
63    /// Builds a cache over an injected [`Redis`](crate::Redis) connection, so the
64    /// cache shares one connection pool with raw Redis access, a rate limiter, and
65    /// anything else built on the same handle.
66    ///
67    /// Available with the `redis` feature.
68    #[cfg(feature = "redis")]
69    pub fn from_redis(redis: &crate::Redis) -> Self {
70        Self::new(super::RedisStore::from_redis(
71            redis,
72            super::RedisStore::default_prefix(),
73        ))
74    }
75
76    /// Sets the TTL applied by [`set`](Cache::set) when no explicit TTL is given.
77    ///
78    /// Without this, `set` stores entries with no expiry (they live until evicted
79    /// by the store's capacity limit).
80    pub fn default_ttl(mut self, ttl: Duration) -> Self {
81        self.default_ttl = normalize_ttl(Some(ttl));
82        self
83    }
84
85    /// Returns the value stored under `key`, or `None` if absent or expired.
86    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
87        match self.store.get(key).await? {
88            Some(bytes) => {
89                let value = serde_json::from_slice(&bytes).map_err(|error| {
90                    Error::internal(format!("cache value could not be deserialized: {error}"))
91                })?;
92                Ok(Some(value))
93            }
94            None => Ok(None),
95        }
96    }
97
98    /// Stores `value` under `key` using the cache's default TTL.
99    pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
100        self.write(key, value, self.default_ttl).await
101    }
102
103    /// Stores `value` under `key`, expiring it after `ttl`.
104    ///
105    /// A zero `ttl` means the entry never expires.
106    pub async fn set_ttl<T: Serialize>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
107        self.write(key, value, normalize_ttl(Some(ttl))).await
108    }
109
110    /// Returns the cached value under `key`, or computes it with `init`, stores it
111    /// (with `ttl`, falling back to the default TTL), and returns it.
112    ///
113    /// This is the cache-aside pattern in one call: a hit returns immediately
114    /// without running `init`; a miss runs `init` once and caches the result.
115    pub async fn get_or_set<T, F, Fut>(
116        &self,
117        key: &str,
118        ttl: Option<Duration>,
119        init: F,
120    ) -> Result<T>
121    where
122        T: Serialize + DeserializeOwned,
123        F: FnOnce() -> Fut,
124        Fut: Future<Output = Result<T>>,
125    {
126        if let Some(found) = self.get::<T>(key).await? {
127            return Ok(found);
128        }
129        let value = init().await?;
130        let ttl = match ttl {
131            Some(ttl) => normalize_ttl(Some(ttl)),
132            None => self.default_ttl,
133        };
134        self.write(key, &value, ttl).await?;
135        Ok(value)
136    }
137
138    /// Removes the entry under `key`, if any.
139    pub async fn delete(&self, key: &str) -> Result<()> {
140        self.store.delete(key).await
141    }
142
143    /// Removes every entry from the cache.
144    pub async fn clear(&self) -> Result<()> {
145        self.store.clear().await
146    }
147
148    /// Serializes `value` and writes it to the store.
149    async fn write<T: Serialize>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()> {
150        let bytes = serde_json::to_vec(value).map_err(|error| {
151            Error::internal(format!("cache value could not be serialized: {error}"))
152        })?;
153        self.store.set(key.to_owned(), bytes, ttl).await
154    }
155}
156
/// Maps a requested TTL to the value handed to the store.
///
/// `Duration::ZERO` is the "never expire" sentinel and becomes `None`; any other
/// duration, and `None` itself, passes through unchanged.
fn normalize_ttl(ttl: Option<Duration>) -> Option<Duration> {
    ttl.filter(|duration| !duration.is_zero())
}
164
165impl FromRequest for Cache {
166    fn from_request(ctx: &RequestContext) -> impl Future<Output = Result<Self>> + Send {
167        let resolved = ctx
168            .state()
169            .get::<Cache>()
170            .map(|cache| (*cache).clone())
171            .ok_or_else(|| {
172                Error::internal("cache is not configured; call `App::cache(...)` to enable it")
173            });
174        async move { resolved }
175    }
176}
177
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde::Deserialize;

    use super::*;

    /// A small serde-compatible record used as a cached payload.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct User {
        id: i64,
        name: String,
    }

    /// Reads `key` back as a `String`, unwrapping any store or decoding error.
    async fn read_string(cache: &Cache, key: &str) -> Option<String> {
        cache.get::<String>(key).await.unwrap()
    }

    #[tokio::test]
    async fn round_trips_a_typed_value() {
        let cache = Cache::in_memory();
        let alice = User {
            id: 1,
            name: "alice".into(),
        };

        cache.set("user:1", &alice).await.unwrap();

        assert_eq!(cache.get::<User>("user:1").await.unwrap(), Some(alice));
    }

    #[tokio::test]
    async fn a_missing_key_is_none() {
        let cache = Cache::in_memory();
        assert_eq!(read_string(&cache, "absent").await, None);
    }

    #[tokio::test]
    async fn an_entry_expires_after_its_ttl() {
        let cache = Cache::in_memory();
        let ttl = Duration::from_millis(50);
        cache.set_ttl("k", &"v", ttl).await.unwrap();

        // Readable right after the write...
        assert_eq!(read_string(&cache, "k").await.as_deref(), Some("v"));

        // ...and gone once we are comfortably past the TTL.
        tokio::time::sleep(Duration::from_millis(120)).await;
        assert_eq!(read_string(&cache, "k").await, None);
    }

    #[tokio::test]
    async fn a_zero_ttl_never_expires() {
        let cache = Cache::in_memory();
        cache.set_ttl("k", &"v", Duration::ZERO).await.unwrap();

        tokio::time::sleep(Duration::from_millis(80)).await;
        assert_eq!(read_string(&cache, "k").await.as_deref(), Some("v"));
    }

    #[tokio::test]
    async fn default_ttl_applies_to_plain_set() {
        let cache = Cache::in_memory().default_ttl(Duration::from_millis(50));
        cache.set("k", &"v").await.unwrap();

        tokio::time::sleep(Duration::from_millis(120)).await;
        assert_eq!(read_string(&cache, "k").await, None);
    }

    #[tokio::test]
    async fn get_or_set_computes_once_then_hits_the_cache() {
        let cache = Cache::in_memory();
        let init_calls = AtomicUsize::new(0);

        // Captures `init_calls` by shared reference only, so the closure can be
        // handed to `get_or_set` twice.
        let load_bob = || async {
            init_calls.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(User {
                id: 7,
                name: "bob".into(),
            })
        };

        let first = cache.get_or_set("user:7", None, load_bob).await.unwrap();
        let second = cache.get_or_set("user:7", None, load_bob).await.unwrap();

        assert_eq!(first, second);
        assert_eq!(init_calls.load(Ordering::SeqCst), 1, "init runs only on a miss");
    }

    #[tokio::test]
    async fn delete_and_clear_remove_entries() {
        let cache = Cache::in_memory();
        for (key, value) in [("a", 1), ("b", 2)] {
            cache.set(key, &value).await.unwrap();
        }

        // Deleting one key leaves its neighbour intact.
        cache.delete("a").await.unwrap();
        assert_eq!(cache.get::<i32>("a").await.unwrap(), None);
        assert_eq!(cache.get::<i32>("b").await.unwrap(), Some(2));

        // Clearing empties everything that is left.
        cache.clear().await.unwrap();
        assert_eq!(cache.get::<i32>("b").await.unwrap(), None);
    }
}