Skip to main content

salvo_cache/
lib.rs

1#![cfg_attr(test, allow(clippy::unwrap_used))]
2//! Response caching middleware for the Salvo web framework.
3//!
4//! This middleware intercepts HTTP responses and caches them for subsequent
5//! requests, reducing server load and improving response times for cacheable
6//! content.
7//!
8//! # What Gets Cached
9//!
10//! The cache stores the complete response including:
11//! - HTTP status code
12//! - Response headers
13//! - Response body (except for streaming responses)
14//!
15//! # Key Components
16//!
17//! - [`CacheIssuer`]: Determines the cache key for each request
18//! - [`CacheStore`]: Backend storage for cached responses
19//! - [`Cache`]: The middleware handler
20//!
21//! # Default Implementations
22//!
23//! - [`RequestIssuer`]: Generates cache keys from the request URI and method
24//! - [`MokaStore`]: High-performance concurrent cache backed by [`moka`]
25//!
26//! # Example
27//!
28//! ```ignore
29//! use std::time::Duration;
30//! use salvo_cache::{Cache, MokaStore, RequestIssuer};
31//! use salvo_core::prelude::*;
32//!
33//! let cache = Cache::new(
34//!     MokaStore::builder()
35//!         .time_to_live(Duration::from_secs(300))  // Cache for 5 minutes
36//!         .build(),
37//!     RequestIssuer::default(),
38//! );
39//!
40//! let router = Router::new()
41//!     .hoop(cache)
42//!     .get(my_expensive_handler);
43//! ```
44//!
45//! # Custom Cache Keys
46//!
47//! Implement [`CacheIssuer`] to customize cache key generation:
48//!
49//! ```ignore
50//! use salvo_cache::CacheIssuer;
51//!
52//! struct UserBasedIssuer;
53//! impl CacheIssuer for UserBasedIssuer {
54//!     type Key = String;
55//!
56//!     async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
57//!         // Cache per user + path
58//!         let user_id = depot.get::<String>("user_id").ok()?;
59//!         Some(format!("{}:{}", user_id, req.uri().path()))
60//!     }
61//! }
62//! ```
63//!
64//! # Skipping Cache
65//!
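//! By default, only GET requests are cached. Use the `skipper` method to customize which
//! requests bypass the cache. Note that a custom skipper *replaces* the default GET-only
//! skipper, so include a method check yourself if non-GET requests should keep bypassing
//! the cache:
//!
//! ```ignore
//! use salvo_core::http::Method;
//!
//! let cache = Cache::new(store, issuer).skipper(|req, _depot| {
//!     req.method() != Method::GET || req.uri().path().starts_with("/api/")
//! });
//! ```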
72//!
73//! # Limitations
74//!
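//! - Streaming responses ([`ResBody::Stream`]) cannot be cached.
//! - Responses whose body is an error ([`ResBody::Error`]) are not cached. A response with
//!   an error status code but a regular body (for example a rendered error page) *is*
//!   cached.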
77//!
78//! Read more: <https://salvo.rs>
79#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
80#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
81#![cfg_attr(docsrs, feature(doc_cfg))]
82
83use std::borrow::Borrow;
84use std::collections::VecDeque;
85use std::error::Error as StdError;
86use std::fmt::{self, Debug, Formatter};
87use std::hash::Hash;
88
89use bytes::Bytes;
90use salvo_core::handler::Skipper;
91use salvo_core::http::{HeaderMap, ResBody, StatusCode};
92use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
93
94mod skipper;
95pub use skipper::MethodSkipper;
96
97#[macro_use]
98mod cfg;
99
100cfg_feature! {
101    #![feature = "moka-store"]
102
103    pub mod moka_store;
104    pub use moka_store::{MokaStore};
105}
106
107/// Issuer
108pub trait CacheIssuer: Send + Sync + 'static {
109    /// The key is used to identify the rate limit.
110    type Key: Hash + Eq + Send + Sync + 'static;
111    /// Issue a new key for the request. If it returns `None`, the request will not be cached.
112    fn issue(
113        &self,
114        req: &mut Request,
115        depot: &Depot,
116    ) -> impl Future<Output = Option<Self::Key>> + Send;
117}
118impl<F, K> CacheIssuer for F
119where
120    F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
121    K: Hash + Eq + Send + Sync + 'static,
122{
123    type Key = K;
124    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
125        (self)(req, depot)
126    }
127}
128
/// Issues cache keys built from the request URI and method.
///
/// With every component enabled (the default), a key looks like
/// `{scheme}://{authority}{path}?{query}|{method}`. Each component can be switched off
/// with its builder method, and components missing from the request URI (for example an
/// absent query string) are simply left out.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    /// Include the URI scheme, followed by `://`.
    use_scheme: bool,
    /// Include the URI authority.
    use_authority: bool,
    /// Include the URI path.
    use_path: bool,
    /// Include the URI query, prefixed with `?`.
    use_query: bool,
    /// Include the request method, prefixed with `|`.
    use_method: bool,
}
impl Default for RequestIssuer {
    fn default() -> Self {
        Self::new()
    }
}
impl RequestIssuer {
    /// Creates a `RequestIssuer` with every key component enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            use_scheme: true,
            use_authority: true,
            use_path: true,
            use_query: true,
            use_method: true,
        }
    }
    /// Sets whether the request's URI scheme is part of the key.
    #[must_use]
    pub fn use_scheme(mut self, value: bool) -> Self {
        self.use_scheme = value;
        self
    }
    /// Sets whether the request's URI authority is part of the key.
    #[must_use]
    pub fn use_authority(mut self, value: bool) -> Self {
        self.use_authority = value;
        self
    }
    /// Sets whether the request's URI path is part of the key.
    #[must_use]
    pub fn use_path(mut self, value: bool) -> Self {
        self.use_path = value;
        self
    }
    /// Sets whether the request's URI query is part of the key.
    #[must_use]
    pub fn use_query(mut self, value: bool) -> Self {
        self.use_query = value;
        self
    }
    /// Sets whether the request method is part of the key.
    #[must_use]
    pub fn use_method(mut self, value: bool) -> Self {
        self.use_method = value;
        self
    }
}
186
187impl CacheIssuer for RequestIssuer {
188    type Key = String;
189    async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
190        let mut key = String::with_capacity(req.uri().path().len() + 16);
191        if self.use_scheme
192            && let Some(scheme) = req.uri().scheme_str()
193        {
194            key.push_str(scheme);
195            key.push_str("://");
196        }
197        if self.use_authority
198            && let Some(authority) = req.uri().authority()
199        {
200            key.push_str(authority.as_str());
201        }
202        if self.use_path {
203            key.push_str(req.uri().path());
204        }
205        if self.use_query
206            && let Some(query) = req.uri().query()
207        {
208            key.push('?');
209            key.push_str(query);
210        }
211        if self.use_method {
212            key.push('|');
213            key.push_str(req.method().as_str());
214        }
215        Some(key)
216    }
217}
218
219/// Store cache.
220pub trait CacheStore: Send + Sync + 'static {
221    /// Error type for CacheStore.
222    type Error: StdError + Sync + Send + 'static;
223    /// Key
224    type Key: Hash + Eq + Send + Clone + 'static;
225    /// Get the cache item from the store.
226    fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
227    where
228        Self::Key: Borrow<Q>,
229        Q: Hash + Eq + Sync;
230    /// Save the cache item to the store.
231    fn save_entry(
232        &self,
233        key: Self::Key,
234        data: CachedEntry,
235    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
236}
237
238/// `CachedBody` is used to save the response body to `CacheStore`.
239///
240/// [`ResBody`] has a Stream type, which is not `Send + Sync`, so we need to convert it to
241/// `CachedBody`. If the response's body is [`ResBody::Stream`], it will not be cached.
242#[derive(Clone, Debug, PartialEq)]
243#[non_exhaustive]
244pub enum CachedBody {
245    /// No body.
246    None,
247    /// Single bytes body.
248    Once(Bytes),
249    /// Chunks body.
250    Chunks(VecDeque<Bytes>),
251}
252impl TryFrom<&ResBody> for CachedBody {
253    type Error = Error;
254    fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
255        match body {
256            ResBody::None => Ok(Self::None),
257            ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
258            ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
259            _ => Err(Error::other("unsupported body type")),
260        }
261    }
262}
263impl From<CachedBody> for ResBody {
264    fn from(body: CachedBody) -> Self {
265        match body {
266            CachedBody::None => Self::None,
267            CachedBody::Once(bytes) => Self::Once(bytes),
268            CachedBody::Chunks(chunks) => Self::Chunks(chunks),
269        }
270    }
271}
272
273/// Cached entry which will be stored in the cache store.
274#[derive(Clone, Debug)]
275#[non_exhaustive]
276pub struct CachedEntry {
277    /// Response status.
278    pub status: Option<StatusCode>,
279    /// Response headers.
280    pub headers: HeaderMap,
281    /// Response body.
282    ///
283    /// *Notice: If the response's body is streaming, it will be ignored and not cached.
284    pub body: CachedBody,
285}
286impl CachedEntry {
287    /// Create a new `CachedEntry`.
288    pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
289        Self {
290            status,
291            headers,
292            body,
293        }
294    }
295
296    /// Get the response status.
297    pub fn status(&self) -> Option<StatusCode> {
298        self.status
299    }
300
301    /// Get the response headers.
302    pub fn headers(&self) -> &HeaderMap {
303        &self.headers
304    }
305
306    /// Get the response body.
307    ///
308    /// *Notice: If the response's body is streaming, it will be ignored and not cached.
309    pub fn body(&self) -> &CachedBody {
310        &self.body
311    }
312}
313
314/// Cache middleware.
315///
316/// # Example
317///
318/// ```
319/// use std::time::Duration;
320///
321/// use salvo_cache::{Cache, MokaStore, RequestIssuer};
322/// use salvo_core::Router;
323///
324/// let cache = Cache::new(
325///     MokaStore::builder()
326///         .time_to_live(Duration::from_secs(60))
327///         .build(),
328///     RequestIssuer::default(),
329/// );
330/// let router = Router::new().hoop(cache);
331/// ```
332#[non_exhaustive]
333pub struct Cache<S, I> {
334    /// Cache store.
335    pub store: S,
336    /// Cache issuer.
337    pub issuer: I,
338    /// Skipper.
339    pub skipper: Box<dyn Skipper>,
340}
341impl<S, I> Debug for Cache<S, I>
342where
343    S: Debug,
344    I: Debug,
345{
346    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
347        f.debug_struct("Cache")
348            .field("store", &self.store)
349            .field("issuer", &self.issuer)
350            .finish()
351    }
352}
353
354impl<S, I> Cache<S, I> {
355    /// Create a new `Cache`.
356    #[inline]
357    #[must_use]
358    pub fn new(store: S, issuer: I) -> Self {
359        let skipper = MethodSkipper::new().skip_all().skip_get(false);
360        Self {
361            store,
362            issuer,
363            skipper: Box::new(skipper),
364        }
365    }
366    /// Sets skipper and returns a new `Cache`.
367    #[inline]
368    #[must_use]
369    pub fn skipper(mut self, skipper: impl Skipper) -> Self {
370        self.skipper = Box::new(skipper);
371        self
372    }
373}
374
375#[async_trait]
376impl<S, I> Handler for Cache<S, I>
377where
378    S: CacheStore<Key = I::Key>,
379    I: CacheIssuer,
380{
381    async fn handle(
382        &self,
383        req: &mut Request,
384        depot: &mut Depot,
385        res: &mut Response,
386        ctrl: &mut FlowCtrl,
387    ) {
388        if self.skipper.skipped(req, depot) {
389            ctrl.call_next(req, depot, res).await;
390            return;
391        }
392        let Some(key) = self.issuer.issue(req, depot).await else {
393            return;
394        };
395        let Some(cache) = self.store.load_entry(&key).await else {
396            ctrl.call_next(req, depot, res).await;
397            if !res.body.is_stream() && !res.body.is_error() {
398                let headers = res.headers().clone();
399                let body = TryInto::<CachedBody>::try_into(&res.body);
400                match body {
401                    Ok(body) => {
402                        let cached_data = CachedEntry::new(res.status_code, headers, body);
403                        if let Err(e) = self.store.save_entry(key, cached_data).await {
404                            tracing::error!(error = ?e, "cache failed");
405                        }
406                    }
407                    Err(e) => tracing::error!(error = ?e, "cache failed"),
408                }
409            }
410            return;
411        };
412        let CachedEntry {
413            status,
414            headers,
415            body,
416        } = cache;
417        if let Some(status) = status {
418            res.status_code(status);
419        }
420        *res.headers_mut() = headers;
421        *res.body_mut() = body.into();
422        ctrl.skip_rest();
423    }
424}
425
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use bytes::Bytes;
    use salvo_core::http::HeaderMap;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    use super::*;

    #[handler]
    async fn cached() -> String {
        format!(
            "Hello World, my birth time is {}",
            OffsetDateTime::now_utc()
        )
    }

    /// Number of times `counted` has run; used to detect whether a response came from cache.
    static COUNTED_CALLS: AtomicUsize = AtomicUsize::new(0);

    #[handler]
    async fn counted() -> String {
        COUNTED_CALLS.fetch_add(1, Ordering::SeqCst).to_string()
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(5))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let content0 = res.take_string().await.unwrap();

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // Wait past the 5s TTL and repeat the *same* GET. A POST would bypass the cache via
        // the default GET-only skipper and pass this check even if expiry were broken.
        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content2 = res.take_string().await.unwrap();

        assert_ne!(content0, content2);
    }

    #[tokio::test]
    async fn test_cache_skips_non_get_requests() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(60))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(counted);
        let service = Service::new(router);

        let mut res = TestClient::post("http://127.0.0.1:5801/post-path")
            .send(&service)
            .await;
        let first = res.take_string().await.unwrap();

        let mut res = TestClient::post("http://127.0.0.1:5801/post-path")
            .send(&service)
            .await;
        let second = res.take_string().await.unwrap();

        // The handler ran twice, so POST responses were not served from the cache.
        assert_ne!(first, second);
    }

    // Tests for RequestIssuer
    #[test]
    fn test_request_issuer_new() {
        let issuer = RequestIssuer::new();
        assert!(issuer.use_scheme);
        assert!(issuer.use_authority);
        assert!(issuer.use_path);
        assert!(issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_default() {
        let issuer = RequestIssuer::default();
        assert!(issuer.use_scheme);
        assert!(issuer.use_authority);
        assert!(issuer.use_path);
        assert!(issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_use_scheme() {
        let issuer = RequestIssuer::new().use_scheme(false);
        assert!(!issuer.use_scheme);
        assert!(issuer.use_authority);
    }

    #[test]
    fn test_request_issuer_use_authority() {
        let issuer = RequestIssuer::new().use_authority(false);
        assert!(issuer.use_scheme);
        assert!(!issuer.use_authority);
    }

    #[test]
    fn test_request_issuer_use_path() {
        let issuer = RequestIssuer::new().use_path(false);
        assert!(!issuer.use_path);
    }

    #[test]
    fn test_request_issuer_use_query() {
        let issuer = RequestIssuer::new().use_query(false);
        assert!(!issuer.use_query);
    }

    #[test]
    fn test_request_issuer_use_method() {
        let issuer = RequestIssuer::new().use_method(false);
        assert!(!issuer.use_method);
    }

    #[test]
    fn test_request_issuer_chain() {
        let issuer = RequestIssuer::new()
            .use_scheme(false)
            .use_authority(false)
            .use_path(true)
            .use_query(false)
            .use_method(true);
        assert!(!issuer.use_scheme);
        assert!(!issuer.use_authority);
        assert!(issuer.use_path);
        assert!(!issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_debug() {
        let issuer = RequestIssuer::new();
        let debug_str = format!("{issuer:?}");
        assert!(debug_str.contains("RequestIssuer"));
        assert!(debug_str.contains("use_scheme"));
    }

    #[test]
    fn test_request_issuer_clone() {
        let issuer = RequestIssuer::new().use_scheme(false);
        let cloned = issuer.clone();
        assert_eq!(issuer.use_scheme, cloned.use_scheme);
        assert_eq!(issuer.use_authority, cloned.use_authority);
    }

    // Tests for CachedBody
    #[test]
    fn test_cached_body_none() {
        let body = CachedBody::None;
        assert_eq!(body, CachedBody::None);
    }

    #[test]
    fn test_cached_body_once() {
        let bytes = Bytes::from("test data");
        let body = CachedBody::Once(bytes.clone());
        assert_eq!(body, CachedBody::Once(bytes));
    }

    #[test]
    fn test_cached_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        chunks.push_back(Bytes::from("chunk2"));
        let body = CachedBody::Chunks(chunks.clone());
        assert_eq!(body, CachedBody::Chunks(chunks));
    }

    #[test]
    fn test_cached_body_try_from_res_body_none() {
        let res_body = ResBody::None;
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::None);
    }

    #[test]
    fn test_cached_body_try_from_res_body_once() {
        let bytes = Bytes::from("test");
        let res_body = ResBody::Once(bytes.clone());
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::Once(bytes));
    }

    #[test]
    fn test_cached_body_try_from_res_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        chunks.push_back(Bytes::from("chunk2"));
        let res_body = ResBody::Chunks(chunks.clone());
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::Chunks(chunks));
    }

    #[test]
    fn test_cached_body_into_res_body_none() {
        let cb = CachedBody::None;
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::None));
    }

    #[test]
    fn test_cached_body_into_res_body_once() {
        let bytes = Bytes::from("test");
        let cb = CachedBody::Once(bytes.clone());
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::Once(b) if b == bytes));
    }

    #[test]
    fn test_cached_body_into_res_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        let cb = CachedBody::Chunks(chunks);
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::Chunks(_)));
    }

    #[test]
    fn test_cached_body_debug() {
        let body = CachedBody::None;
        let debug_str = format!("{body:?}");
        assert!(debug_str.contains("None"));

        let body = CachedBody::Once(Bytes::from("test"));
        let debug_str = format!("{body:?}");
        assert!(debug_str.contains("Once"));
    }

    #[test]
    fn test_cached_body_clone() {
        let body = CachedBody::Once(Bytes::from("test"));
        let cloned = body.clone();
        assert_eq!(body, cloned);
    }

    // Tests for CachedEntry
    #[test]
    fn test_cached_entry_new() {
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), CachedBody::None);
        assert_eq!(entry.status, Some(StatusCode::OK));
        assert!(entry.headers.is_empty());
        assert_eq!(entry.body, CachedBody::None);
    }

    #[test]
    fn test_cached_entry_status() {
        let entry = CachedEntry::new(
            Some(StatusCode::NOT_FOUND),
            HeaderMap::new(),
            CachedBody::None,
        );
        assert_eq!(entry.status(), Some(StatusCode::NOT_FOUND));
    }

    #[test]
    fn test_cached_entry_status_none() {
        let entry = CachedEntry::new(None, HeaderMap::new(), CachedBody::None);
        assert_eq!(entry.status(), None);
    }

    #[test]
    fn test_cached_entry_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("Content-Type", "application/json".parse().unwrap());
        let entry = CachedEntry::new(Some(StatusCode::OK), headers.clone(), CachedBody::None);
        assert_eq!(entry.headers().len(), 1);
        assert!(entry.headers().contains_key("Content-Type"));
    }

    #[test]
    fn test_cached_entry_body() {
        let body = CachedBody::Once(Bytes::from("test body"));
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), body.clone());
        assert_eq!(entry.body(), &body);
    }

    #[test]
    fn test_cached_entry_debug() {
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), CachedBody::None);
        let debug_str = format!("{entry:?}");
        assert!(debug_str.contains("CachedEntry"));
        assert!(debug_str.contains("status"));
    }

    #[test]
    fn test_cached_entry_clone() {
        let entry = CachedEntry::new(
            Some(StatusCode::OK),
            HeaderMap::new(),
            CachedBody::Once(Bytes::from("test")),
        );
        let cloned = entry.clone();
        assert_eq!(entry.status, cloned.status);
        assert_eq!(entry.body, cloned.body);
    }

    // Tests for Cache
    #[test]
    fn test_cache_new() {
        let cache = Cache::new(MokaStore::<String>::new(100), RequestIssuer::default());
        assert!(format!("{cache:?}").contains("Cache"));
    }

    #[test]
    fn test_cache_debug() {
        let cache = Cache::new(MokaStore::<String>::new(100), RequestIssuer::default());
        let debug_str = format!("{cache:?}");
        assert!(debug_str.contains("Cache"));
        assert!(debug_str.contains("store"));
        assert!(debug_str.contains("issuer"));
    }

    #[tokio::test]
    async fn test_cache_same_path_same_content() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(60))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res1 = TestClient::get("http://127.0.0.1:5801/same-path")
            .send(&service)
            .await;
        let content1 = res1.take_string().await.unwrap();

        let mut res2 = TestClient::get("http://127.0.0.1:5801/same-path")
            .send(&service)
            .await;
        let content2 = res2.take_string().await.unwrap();

        // Same path should return cached content
        assert_eq!(content1, content2);
    }
}