salvo_cache/
lib.rs

1//! Cache middleware for the Salvo web framework.
2//!
3//! Cache middleware for Salvo designed to intercept responses and cache them.
4//! This middleware will cache the response's StatusCode, Headers, and Body.
5//!
6//! You can define your custom [`CacheIssuer`] to determine which responses should be cached,
7//! or you can use the default [`RequestIssuer`].
8//!
//! The default cache store is [`MokaStore`] (available with the `moka-store` feature), a wrapper around [`moka`].
10//! You can define your own cache store by implementing [`CacheStore`].
11//!
12//! Example: [cache-simple](https://github.com/salvo-rs/salvo/tree/main/examples/cache-simple)
13//! Read more: <https://salvo.rs>
14#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
15#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
16#![cfg_attr(docsrs, feature(doc_cfg))]
17
18use std::borrow::Borrow;
19use std::collections::VecDeque;
20use std::error::Error as StdError;
21use std::fmt::{self, Debug, Formatter};
22use std::hash::Hash;
23
24use bytes::Bytes;
25use salvo_core::handler::Skipper;
26use salvo_core::http::{HeaderMap, ResBody, StatusCode};
27use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
28
29mod skipper;
30pub use skipper::MethodSkipper;
31
32#[macro_use]
33mod cfg;
34
35cfg_feature! {
36    #![feature = "moka-store"]
37
38    pub mod moka_store;
39    pub use moka_store::{MokaStore};
40}
41
42/// Issuer
43pub trait CacheIssuer: Send + Sync + 'static {
44    /// The key is used to identify the rate limit.
45    type Key: Hash + Eq + Send + Sync + 'static;
46    /// Issue a new key for the request. If it returns `None`, the request will not be cached.
47    fn issue(
48        &self,
49        req: &mut Request,
50        depot: &Depot,
51    ) -> impl Future<Output = Option<Self::Key>> + Send;
52}
53impl<F, K> CacheIssuer for F
54where
55    F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
56    K: Hash + Eq + Send + Sync + 'static,
57{
58    type Key = K;
59    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
60        (self)(req, depot)
61    }
62}
63
/// Cache-key issuer that builds a `String` key from parts of the request.
///
/// The key is assembled, in this order, from the URI scheme (as `scheme://`), the URI
/// authority, the URI path, the URI query (as `?query`) and the request method (as `|METHOD`).
/// Every part is enabled by default; each can be switched off with its builder method.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    /// Include the URI scheme in the key.
    use_scheme: bool,
    /// Include the URI authority in the key.
    use_authority: bool,
    /// Include the URI path in the key.
    use_path: bool,
    /// Include the URI query in the key.
    use_query: bool,
    /// Include the request method in the key.
    use_method: bool,
}
impl Default for RequestIssuer {
    /// Same as [`RequestIssuer::new`]: every key part is enabled.
    fn default() -> Self {
        Self::new()
    }
}
impl RequestIssuer {
    /// Create a new `RequestIssuer` with every key part enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            use_scheme: true,
            use_authority: true,
            use_path: true,
            use_query: true,
            use_method: true,
        }
    }
    /// Whether to use the request's URI scheme when generating the key.
    #[must_use]
    pub fn use_scheme(mut self, value: bool) -> Self {
        self.use_scheme = value;
        self
    }
    /// Whether to use the request's URI authority when generating the key.
    #[must_use]
    pub fn use_authority(mut self, value: bool) -> Self {
        self.use_authority = value;
        self
    }
    /// Whether to use the request's URI path when generating the key.
    #[must_use]
    pub fn use_path(mut self, value: bool) -> Self {
        self.use_path = value;
        self
    }
    /// Whether to use the request's URI query when generating the key.
    #[must_use]
    pub fn use_query(mut self, value: bool) -> Self {
        self.use_query = value;
        self
    }
    /// Whether to use the request method when generating the key.
    #[must_use]
    pub fn use_method(mut self, value: bool) -> Self {
        self.use_method = value;
        self
    }
}
121
122impl CacheIssuer for RequestIssuer {
123    type Key = String;
124    async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
125        let mut key = String::new();
126        if self.use_scheme
127            && let Some(scheme) = req.uri().scheme_str()
128        {
129            key.push_str(scheme);
130            key.push_str("://");
131        }
132        if self.use_authority
133            && let Some(authority) = req.uri().authority()
134        {
135            key.push_str(authority.as_str());
136        }
137        if self.use_path {
138            key.push_str(req.uri().path());
139        }
140        if self.use_query
141            && let Some(query) = req.uri().query()
142        {
143            key.push('?');
144            key.push_str(query);
145        }
146        if self.use_method {
147            key.push('|');
148            key.push_str(req.method().as_str());
149        }
150        Some(key)
151    }
152}
153
154/// Store cache.
155pub trait CacheStore: Send + Sync + 'static {
156    /// Error type for CacheStore.
157    type Error: StdError + Sync + Send + 'static;
158    /// Key
159    type Key: Hash + Eq + Send + Clone + 'static;
160    /// Get the cache item from the store.
161    fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
162    where
163        Self::Key: Borrow<Q>,
164        Q: Hash + Eq + Sync;
165    /// Save the cache item to the store.
166    fn save_entry(
167        &self,
168        key: Self::Key,
169        data: CachedEntry,
170    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
171}
172
173/// `CachedBody` is used to save the response body to `CacheStore`.
174///
175/// [`ResBody`] has a Stream type, which is not `Send + Sync`, so we need to convert it to
176/// `CachedBody`. If the response's body is [`ResBody::Stream`], it will not be cached.
177#[derive(Clone, Debug, PartialEq)]
178#[non_exhaustive]
179pub enum CachedBody {
180    /// No body.
181    None,
182    /// Single bytes body.
183    Once(Bytes),
184    /// Chunks body.
185    Chunks(VecDeque<Bytes>),
186}
187impl TryFrom<&ResBody> for CachedBody {
188    type Error = Error;
189    fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
190        match body {
191            ResBody::None => Ok(Self::None),
192            ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
193            ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
194            _ => Err(Error::other("unsupported body type")),
195        }
196    }
197}
198impl From<CachedBody> for ResBody {
199    fn from(body: CachedBody) -> Self {
200        match body {
201            CachedBody::None => Self::None,
202            CachedBody::Once(bytes) => Self::Once(bytes),
203            CachedBody::Chunks(chunks) => Self::Chunks(chunks),
204        }
205    }
206}
207
208/// Cached entry which will be stored in the cache store.
209#[derive(Clone, Debug)]
210#[non_exhaustive]
211pub struct CachedEntry {
212    /// Response status.
213    pub status: Option<StatusCode>,
214    /// Response headers.
215    pub headers: HeaderMap,
216    /// Response body.
217    ///
218    /// *Notice: If the response's body is streaming, it will be ignored and not cached.
219    pub body: CachedBody,
220}
221impl CachedEntry {
222    /// Create a new `CachedEntry`.
223    pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
224        Self {
225            status,
226            headers,
227            body,
228        }
229    }
230
231    /// Get the response status.
232    pub fn status(&self) -> Option<StatusCode> {
233        self.status
234    }
235
236    /// Get the response headers.
237    pub fn headers(&self) -> &HeaderMap {
238        &self.headers
239    }
240
241    /// Get the response body.
242    ///
243    /// *Notice: If the response's body is streaming, it will be ignored and not cached.
244    pub fn body(&self) -> &CachedBody {
245        &self.body
246    }
247}
248
249/// Cache middleware.
250///
251/// # Example
252///
253/// ```
254/// use std::time::Duration;
255///
256/// use salvo_cache::{Cache, MokaStore, RequestIssuer};
257/// use salvo_core::Router;
258///
259/// let cache = Cache::new(
260///     MokaStore::builder()
261///         .time_to_live(Duration::from_secs(60))
262///         .build(),
263///     RequestIssuer::default(),
264/// );
265/// let router = Router::new().hoop(cache);
266/// ```
267#[non_exhaustive]
268pub struct Cache<S, I> {
269    /// Cache store.
270    pub store: S,
271    /// Cache issuer.
272    pub issuer: I,
273    /// Skipper.
274    pub skipper: Box<dyn Skipper>,
275}
276impl<S, I> Debug for Cache<S, I>
277where
278    S: Debug,
279    I: Debug,
280{
281    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
282        f.debug_struct("Cache")
283            .field("store", &self.store)
284            .field("issuer", &self.issuer)
285            .finish()
286    }
287}
288
289impl<S, I> Cache<S, I> {
290    /// Create a new `Cache`.
291    #[inline]
292    #[must_use]
293    pub fn new(store: S, issuer: I) -> Self {
294        let skipper = MethodSkipper::new().skip_all().skip_get(false);
295        Self {
296            store,
297            issuer,
298            skipper: Box::new(skipper),
299        }
300    }
301    /// Sets skipper and returns a new `Cache`.
302    #[inline]
303    #[must_use]
304    pub fn skipper(mut self, skipper: impl Skipper) -> Self {
305        self.skipper = Box::new(skipper);
306        self
307    }
308}
309
310#[async_trait]
311impl<S, I> Handler for Cache<S, I>
312where
313    S: CacheStore<Key = I::Key>,
314    I: CacheIssuer,
315{
316    async fn handle(
317        &self,
318        req: &mut Request,
319        depot: &mut Depot,
320        res: &mut Response,
321        ctrl: &mut FlowCtrl,
322    ) {
323        if self.skipper.skipped(req, depot) {
324            return;
325        }
326        let Some(key) = self.issuer.issue(req, depot).await else {
327            return;
328        };
329        let Some(cache) = self.store.load_entry(&key).await else {
330            ctrl.call_next(req, depot, res).await;
331            if !res.body.is_stream() && !res.body.is_error() {
332                let headers = res.headers().clone();
333                let body = TryInto::<CachedBody>::try_into(&res.body);
334                match body {
335                    Ok(body) => {
336                        let cached_data = CachedEntry::new(res.status_code, headers, body);
337                        if let Err(e) = self.store.save_entry(key, cached_data).await {
338                            tracing::error!(error = ?e, "cache failed");
339                        }
340                    }
341                    Err(e) => tracing::error!(error = ?e, "cache failed"),
342                }
343            }
344            return;
345        };
346        let CachedEntry {
347            status,
348            headers,
349            body,
350        } = cache;
351        if let Some(status) = status {
352            res.status_code(status);
353        }
354        *res.headers_mut() = headers;
355        *res.body_mut() = body.into();
356        ctrl.skip_rest();
357    }
358}
359
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    use super::*;

    /// Returns a body containing the current time, so an identical body on a later request
    /// means it was served from the cache.
    #[handler]
    async fn cached() -> String {
        format!(
            "Hello World, my birth time is {}",
            OffsetDateTime::now_utc()
        )
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(5))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let content0 = res.take_string().await.unwrap();

        // Within the TTL: the second GET must be served from the cache.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // After the TTL the entry must have expired. This request has to be a GET: the default
        // skipper bypasses the cache for every other method, so a POST would return fresh
        // content whether or not the entry had expired.
        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content2 = res.take_string().await.unwrap();

        assert_ne!(content0, content2);
    }
}