Skip to main content

salvo_cache/
lib.rs

1//! Response caching middleware for the Salvo web framework.
2//!
3//! This middleware intercepts HTTP responses and caches them for subsequent
4//! requests, reducing server load and improving response times for cacheable
5//! content.
6//!
7//! # What Gets Cached
8//!
9//! The cache stores the complete response including:
10//! - HTTP status code
11//! - Response headers
12//! - Response body (except for streaming responses)
13//!
14//! # Key Components
15//!
16//! - [`CacheIssuer`]: Determines the cache key for each request
17//! - [`CacheStore`]: Backend storage for cached responses
18//! - [`Cache`]: The middleware handler
19//!
20//! # Default Implementations
21//!
22//! - [`RequestIssuer`]: Generates cache keys from the request URI and method
23//! - [`MokaStore`]: High-performance concurrent cache backed by [`moka`]
24//!
25//! # Example
26//!
27//! ```ignore
28//! use std::time::Duration;
29//! use salvo_cache::{Cache, MokaStore, RequestIssuer};
30//! use salvo_core::prelude::*;
31//!
32//! let cache = Cache::new(
33//!     MokaStore::builder()
34//!         .time_to_live(Duration::from_secs(300))  // Cache for 5 minutes
35//!         .build(),
36//!     RequestIssuer::default(),
37//! );
38//!
39//! let router = Router::new()
40//!     .hoop(cache)
41//!     .get(my_expensive_handler);
42//! ```
43//!
44//! # Custom Cache Keys
45//!
46//! Implement [`CacheIssuer`] to customize cache key generation:
47//!
48//! ```ignore
49//! use salvo_cache::CacheIssuer;
50//!
51//! struct UserBasedIssuer;
52//! impl CacheIssuer for UserBasedIssuer {
53//!     type Key = String;
54//!
55//!     async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
56//!         // Cache per user + path
57//!         let user_id = depot.get::<String>("user_id").ok()?;
58//!         Some(format!("{}:{}", user_id, req.uri().path()))
59//!     }
60//! }
61//! ```
62//!
63//! # Skipping Cache
64//!
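//! By default, only GET requests are cached. Use the `skipper` method to customize which
//! requests bypass the cache: the skipper returns `true` for requests that should be neither
//! served from nor stored in the cache. A custom skipper *replaces* the GET-only default, so
//! check the method yourself if you still want to cache GET requests only:
//!
//! ```ignore
//! let cache = Cache::new(store, issuer).skipper(|req: &mut Request, _depot: &Depot| {
//!     // Bypass the cache for non-GET requests and for everything under `/api/`.
//!     req.method() != Method::GET || req.uri().path().starts_with("/api/")
//! });
//! ```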
71//!
72//! # Limitations
73//!
74//! - Streaming responses ([`ResBody::Stream`]) cannot be cached
75//! - Error responses are not cached
76//!
77//! Read more: <https://salvo.rs>
78#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
79#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
80#![cfg_attr(docsrs, feature(doc_cfg))]
81
82use std::borrow::Borrow;
83use std::collections::VecDeque;
84use std::error::Error as StdError;
85use std::fmt::{self, Debug, Formatter};
86use std::hash::Hash;
87
88use bytes::Bytes;
89use salvo_core::handler::Skipper;
90use salvo_core::http::{HeaderMap, ResBody, StatusCode};
91use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
92
93mod skipper;
94pub use skipper::MethodSkipper;
95
96#[macro_use]
97mod cfg;
98
99cfg_feature! {
100    #![feature = "moka-store"]
101
102    pub mod moka_store;
103    pub use moka_store::{MokaStore};
104}
105
106/// Issuer
107pub trait CacheIssuer: Send + Sync + 'static {
108    /// The key is used to identify the rate limit.
109    type Key: Hash + Eq + Send + Sync + 'static;
110    /// Issue a new key for the request. If it returns `None`, the request will not be cached.
111    fn issue(
112        &self,
113        req: &mut Request,
114        depot: &Depot,
115    ) -> impl Future<Output = Option<Self::Key>> + Send;
116}
117impl<F, K> CacheIssuer for F
118where
119    F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
120    K: Hash + Eq + Send + Sync + 'static,
121{
122    type Key = K;
123    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
124        (self)(req, depot)
125    }
126}
127
/// A [`CacheIssuer`] that builds the cache key from the request URI and method.
///
/// Every component is enabled by default, giving keys shaped like
/// `{scheme}://{authority}{path}?{query}|{METHOD}`. A component the URI does not carry is
/// simply left out. The `use_*` builder methods remove components from the key. Removing one
/// makes requests that differ only in that component share a single cached response.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    /// Prepend the URI scheme followed by `://`.
    use_scheme: bool,
    /// Append the URI authority.
    use_authority: bool,
    /// Append the URI path.
    use_path: bool,
    /// Append `?` and the query string, when one is present.
    use_query: bool,
    /// Append `|` and the request method.
    use_method: bool,
}
impl Default for RequestIssuer {
    /// Same as [`RequestIssuer::new`]: every key component enabled.
    fn default() -> Self {
        Self::new()
    }
}
impl RequestIssuer {
    /// Create a `RequestIssuer` with every key component enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            use_scheme: true,
            use_authority: true,
            use_path: true,
            use_query: true,
            use_method: true,
        }
    }
    /// Whether to include the request's URI scheme in the key (default: `true`).
    #[must_use]
    pub fn use_scheme(mut self, value: bool) -> Self {
        self.use_scheme = value;
        self
    }
    /// Whether to include the request's URI authority in the key (default: `true`).
    #[must_use]
    pub fn use_authority(mut self, value: bool) -> Self {
        self.use_authority = value;
        self
    }
    /// Whether to include the request's URI path in the key (default: `true`).
    #[must_use]
    pub fn use_path(mut self, value: bool) -> Self {
        self.use_path = value;
        self
    }
    /// Whether to include the request's URI query in the key (default: `true`).
    #[must_use]
    pub fn use_query(mut self, value: bool) -> Self {
        self.use_query = value;
        self
    }
    /// Whether to include the request method in the key (default: `true`).
    #[must_use]
    pub fn use_method(mut self, value: bool) -> Self {
        self.use_method = value;
        self
    }
}
185
186impl CacheIssuer for RequestIssuer {
187    type Key = String;
188    async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
189        let mut key = String::new();
190        if self.use_scheme
191            && let Some(scheme) = req.uri().scheme_str()
192        {
193            key.push_str(scheme);
194            key.push_str("://");
195        }
196        if self.use_authority
197            && let Some(authority) = req.uri().authority()
198        {
199            key.push_str(authority.as_str());
200        }
201        if self.use_path {
202            key.push_str(req.uri().path());
203        }
204        if self.use_query
205            && let Some(query) = req.uri().query()
206        {
207            key.push('?');
208            key.push_str(query);
209        }
210        if self.use_method {
211            key.push('|');
212            key.push_str(req.method().as_str());
213        }
214        Some(key)
215    }
216}
217
218/// Store cache.
219pub trait CacheStore: Send + Sync + 'static {
220    /// Error type for CacheStore.
221    type Error: StdError + Sync + Send + 'static;
222    /// Key
223    type Key: Hash + Eq + Send + Clone + 'static;
224    /// Get the cache item from the store.
225    fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
226    where
227        Self::Key: Borrow<Q>,
228        Q: Hash + Eq + Sync;
229    /// Save the cache item to the store.
230    fn save_entry(
231        &self,
232        key: Self::Key,
233        data: CachedEntry,
234    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
235}
236
237/// `CachedBody` is used to save the response body to `CacheStore`.
238///
239/// [`ResBody`] has a Stream type, which is not `Send + Sync`, so we need to convert it to
240/// `CachedBody`. If the response's body is [`ResBody::Stream`], it will not be cached.
241#[derive(Clone, Debug, PartialEq)]
242#[non_exhaustive]
243pub enum CachedBody {
244    /// No body.
245    None,
246    /// Single bytes body.
247    Once(Bytes),
248    /// Chunks body.
249    Chunks(VecDeque<Bytes>),
250}
251impl TryFrom<&ResBody> for CachedBody {
252    type Error = Error;
253    fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
254        match body {
255            ResBody::None => Ok(Self::None),
256            ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
257            ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
258            _ => Err(Error::other("unsupported body type")),
259        }
260    }
261}
262impl From<CachedBody> for ResBody {
263    fn from(body: CachedBody) -> Self {
264        match body {
265            CachedBody::None => Self::None,
266            CachedBody::Once(bytes) => Self::Once(bytes),
267            CachedBody::Chunks(chunks) => Self::Chunks(chunks),
268        }
269    }
270}
271
272/// Cached entry which will be stored in the cache store.
273#[derive(Clone, Debug)]
274#[non_exhaustive]
275pub struct CachedEntry {
276    /// Response status.
277    pub status: Option<StatusCode>,
278    /// Response headers.
279    pub headers: HeaderMap,
280    /// Response body.
281    ///
282    /// *Notice: If the response's body is streaming, it will be ignored and not cached.
283    pub body: CachedBody,
284}
285impl CachedEntry {
286    /// Create a new `CachedEntry`.
287    pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
288        Self {
289            status,
290            headers,
291            body,
292        }
293    }
294
295    /// Get the response status.
296    pub fn status(&self) -> Option<StatusCode> {
297        self.status
298    }
299
300    /// Get the response headers.
301    pub fn headers(&self) -> &HeaderMap {
302        &self.headers
303    }
304
305    /// Get the response body.
306    ///
307    /// *Notice: If the response's body is streaming, it will be ignored and not cached.
308    pub fn body(&self) -> &CachedBody {
309        &self.body
310    }
311}
312
313/// Cache middleware.
314///
315/// # Example
316///
317/// ```
318/// use std::time::Duration;
319///
320/// use salvo_cache::{Cache, MokaStore, RequestIssuer};
321/// use salvo_core::Router;
322///
323/// let cache = Cache::new(
324///     MokaStore::builder()
325///         .time_to_live(Duration::from_secs(60))
326///         .build(),
327///     RequestIssuer::default(),
328/// );
329/// let router = Router::new().hoop(cache);
330/// ```
331#[non_exhaustive]
332pub struct Cache<S, I> {
333    /// Cache store.
334    pub store: S,
335    /// Cache issuer.
336    pub issuer: I,
337    /// Skipper.
338    pub skipper: Box<dyn Skipper>,
339}
340impl<S, I> Debug for Cache<S, I>
341where
342    S: Debug,
343    I: Debug,
344{
345    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
346        f.debug_struct("Cache")
347            .field("store", &self.store)
348            .field("issuer", &self.issuer)
349            .finish()
350    }
351}
352
353impl<S, I> Cache<S, I> {
354    /// Create a new `Cache`.
355    #[inline]
356    #[must_use]
357    pub fn new(store: S, issuer: I) -> Self {
358        let skipper = MethodSkipper::new().skip_all().skip_get(false);
359        Self {
360            store,
361            issuer,
362            skipper: Box::new(skipper),
363        }
364    }
365    /// Sets skipper and returns a new `Cache`.
366    #[inline]
367    #[must_use]
368    pub fn skipper(mut self, skipper: impl Skipper) -> Self {
369        self.skipper = Box::new(skipper);
370        self
371    }
372}
373
374#[async_trait]
375impl<S, I> Handler for Cache<S, I>
376where
377    S: CacheStore<Key = I::Key>,
378    I: CacheIssuer,
379{
380    async fn handle(
381        &self,
382        req: &mut Request,
383        depot: &mut Depot,
384        res: &mut Response,
385        ctrl: &mut FlowCtrl,
386    ) {
387        if self.skipper.skipped(req, depot) {
388            return;
389        }
390        let Some(key) = self.issuer.issue(req, depot).await else {
391            return;
392        };
393        let Some(cache) = self.store.load_entry(&key).await else {
394            ctrl.call_next(req, depot, res).await;
395            if !res.body.is_stream() && !res.body.is_error() {
396                let headers = res.headers().clone();
397                let body = TryInto::<CachedBody>::try_into(&res.body);
398                match body {
399                    Ok(body) => {
400                        let cached_data = CachedEntry::new(res.status_code, headers, body);
401                        if let Err(e) = self.store.save_entry(key, cached_data).await {
402                            tracing::error!(error = ?e, "cache failed");
403                        }
404                    }
405                    Err(e) => tracing::error!(error = ?e, "cache failed"),
406                }
407            }
408            return;
409        };
410        let CachedEntry {
411            status,
412            headers,
413            body,
414        } = cache;
415        if let Some(status) = status {
416            res.status_code(status);
417        }
418        *res.headers_mut() = headers;
419        *res.body_mut() = body.into();
420        ctrl.skip_rest();
421    }
422}
423
424#[cfg(test)]
425mod tests {
426    use salvo_core::prelude::*;
427    use salvo_core::test::{ResponseExt, TestClient};
428    use time::OffsetDateTime;
429
430    use super::*;
431
432    #[handler]
433    async fn cached() -> String {
434        format!(
435            "Hello World, my birth time is {}",
436            OffsetDateTime::now_utc()
437        )
438    }
439
440    #[tokio::test]
441    async fn test_cache() {
442        let cache = Cache::new(
443            MokaStore::builder()
444                .time_to_live(std::time::Duration::from_secs(5))
445                .build(),
446            RequestIssuer::default(),
447        );
448        let router = Router::new().hoop(cache).goal(cached);
449        let service = Service::new(router);
450
451        let mut res = TestClient::get("http://127.0.0.1:5801")
452            .send(&service)
453            .await;
454        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
455
456        let content0 = res.take_string().await.unwrap();
457
458        let mut res = TestClient::get("http://127.0.0.1:5801")
459            .send(&service)
460            .await;
461        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
462
463        let content1 = res.take_string().await.unwrap();
464        assert_eq!(content0, content1);
465
466        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
467        let mut res = TestClient::post("http://127.0.0.1:5801")
468            .send(&service)
469            .await;
470        let content2 = res.take_string().await.unwrap();
471
472        assert_ne!(content0, content2);
473    }
474}