1#![cfg_attr(test, allow(clippy::unwrap_used))]
2#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
80#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
81#![cfg_attr(docsrs, feature(doc_cfg))]
82
83use std::borrow::Borrow;
84use std::collections::VecDeque;
85use std::error::Error as StdError;
86use std::fmt::{self, Debug, Formatter};
87use std::hash::Hash;
88
89use bytes::Bytes;
90use salvo_core::handler::Skipper;
91use salvo_core::http::{HeaderMap, ResBody, StatusCode};
92use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
93
94mod skipper;
95pub use skipper::MethodSkipper;
96
97#[macro_use]
98mod cfg;
99
100cfg_feature! {
101 #![feature = "moka-store"]
102
103 pub mod moka_store;
104 pub use moka_store::{MokaStore};
105}
106
107pub trait CacheIssuer: Send + Sync + 'static {
109 type Key: Hash + Eq + Send + Sync + 'static;
111 fn issue(
113 &self,
114 req: &mut Request,
115 depot: &Depot,
116 ) -> impl Future<Output = Option<Self::Key>> + Send;
117}
118impl<F, K> CacheIssuer for F
119where
120 F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
121 K: Hash + Eq + Send + Sync + 'static,
122{
123 type Key = K;
124 async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
125 (self)(req, depot)
126 }
127}
128
/// Cache issuer that builds a `String` key from selected parts of the request.
///
/// Every part is enabled by default. The resulting key has the shape
/// `[scheme://][authority][path][?query][|METHOD]`, with disabled parts left out.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    /// Include the URI scheme, followed by `://`.
    use_scheme: bool,
    /// Include the URI authority.
    use_authority: bool,
    /// Include the URI path.
    use_path: bool,
    /// Include the query string, prefixed with `?`.
    use_query: bool,
    /// Include the HTTP method, prefixed with `|`.
    use_method: bool,
}

impl Default for RequestIssuer {
    /// Equivalent to [`RequestIssuer::new`]: all key parts enabled.
    fn default() -> Self {
        Self::new()
    }
}

impl RequestIssuer {
    /// Creates an issuer that puts every request part into the key.
    #[must_use]
    pub fn new() -> Self {
        Self {
            use_scheme: true,
            use_authority: true,
            use_path: true,
            use_query: true,
            use_method: true,
        }
    }

    /// Chooses whether the URI scheme is part of the key.
    #[must_use]
    pub fn use_scheme(self, value: bool) -> Self {
        Self {
            use_scheme: value,
            ..self
        }
    }

    /// Chooses whether the URI authority is part of the key.
    #[must_use]
    pub fn use_authority(self, value: bool) -> Self {
        Self {
            use_authority: value,
            ..self
        }
    }

    /// Chooses whether the URI path is part of the key.
    #[must_use]
    pub fn use_path(self, value: bool) -> Self {
        Self {
            use_path: value,
            ..self
        }
    }

    /// Chooses whether the query string is part of the key.
    #[must_use]
    pub fn use_query(self, value: bool) -> Self {
        Self {
            use_query: value,
            ..self
        }
    }

    /// Chooses whether the HTTP method is part of the key.
    #[must_use]
    pub fn use_method(self, value: bool) -> Self {
        Self {
            use_method: value,
            ..self
        }
    }
}
186
187impl CacheIssuer for RequestIssuer {
188 type Key = String;
189 async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
190 let mut key = String::with_capacity(req.uri().path().len() + 16);
191 if self.use_scheme
192 && let Some(scheme) = req.uri().scheme_str()
193 {
194 key.push_str(scheme);
195 key.push_str("://");
196 }
197 if self.use_authority
198 && let Some(authority) = req.uri().authority()
199 {
200 key.push_str(authority.as_str());
201 }
202 if self.use_path {
203 key.push_str(req.uri().path());
204 }
205 if self.use_query
206 && let Some(query) = req.uri().query()
207 {
208 key.push('?');
209 key.push_str(query);
210 }
211 if self.use_method {
212 key.push('|');
213 key.push_str(req.method().as_str());
214 }
215 Some(key)
216 }
217}
218
219pub trait CacheStore: Send + Sync + 'static {
221 type Error: StdError + Sync + Send + 'static;
223 type Key: Hash + Eq + Send + Clone + 'static;
225 fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
227 where
228 Self::Key: Borrow<Q>,
229 Q: Hash + Eq + Sync;
230 fn save_entry(
232 &self,
233 key: Self::Key,
234 data: CachedEntry,
235 ) -> impl Future<Output = Result<(), Self::Error>> + Send;
236}
237
238#[derive(Clone, Debug, PartialEq)]
243#[non_exhaustive]
244pub enum CachedBody {
245 None,
247 Once(Bytes),
249 Chunks(VecDeque<Bytes>),
251}
252impl TryFrom<&ResBody> for CachedBody {
253 type Error = Error;
254 fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
255 match body {
256 ResBody::None => Ok(Self::None),
257 ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
258 ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
259 _ => Err(Error::other("unsupported body type")),
260 }
261 }
262}
263impl From<CachedBody> for ResBody {
264 fn from(body: CachedBody) -> Self {
265 match body {
266 CachedBody::None => Self::None,
267 CachedBody::Once(bytes) => Self::Once(bytes),
268 CachedBody::Chunks(chunks) => Self::Chunks(chunks),
269 }
270 }
271}
272
273#[derive(Clone, Debug)]
275#[non_exhaustive]
276pub struct CachedEntry {
277 pub status: Option<StatusCode>,
279 pub headers: HeaderMap,
281 pub body: CachedBody,
285}
286impl CachedEntry {
287 pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
289 Self {
290 status,
291 headers,
292 body,
293 }
294 }
295
296 pub fn status(&self) -> Option<StatusCode> {
298 self.status
299 }
300
301 pub fn headers(&self) -> &HeaderMap {
303 &self.headers
304 }
305
306 pub fn body(&self) -> &CachedBody {
310 &self.body
311 }
312}
313
314#[non_exhaustive]
333pub struct Cache<S, I> {
334 pub store: S,
336 pub issuer: I,
338 pub skipper: Box<dyn Skipper>,
340}
341impl<S, I> Debug for Cache<S, I>
342where
343 S: Debug,
344 I: Debug,
345{
346 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
347 f.debug_struct("Cache")
348 .field("store", &self.store)
349 .field("issuer", &self.issuer)
350 .finish()
351 }
352}
353
354impl<S, I> Cache<S, I> {
355 #[inline]
357 #[must_use]
358 pub fn new(store: S, issuer: I) -> Self {
359 let skipper = MethodSkipper::new().skip_all().skip_get(false);
360 Self {
361 store,
362 issuer,
363 skipper: Box::new(skipper),
364 }
365 }
366 #[inline]
368 #[must_use]
369 pub fn skipper(mut self, skipper: impl Skipper) -> Self {
370 self.skipper = Box::new(skipper);
371 self
372 }
373}
374
375#[async_trait]
376impl<S, I> Handler for Cache<S, I>
377where
378 S: CacheStore<Key = I::Key>,
379 I: CacheIssuer,
380{
381 async fn handle(
382 &self,
383 req: &mut Request,
384 depot: &mut Depot,
385 res: &mut Response,
386 ctrl: &mut FlowCtrl,
387 ) {
388 if self.skipper.skipped(req, depot) {
389 ctrl.call_next(req, depot, res).await;
390 return;
391 }
392 let Some(key) = self.issuer.issue(req, depot).await else {
393 return;
394 };
395 let Some(cache) = self.store.load_entry(&key).await else {
396 ctrl.call_next(req, depot, res).await;
397 if !res.body.is_stream() && !res.body.is_error() {
398 let headers = res.headers().clone();
399 let body = TryInto::<CachedBody>::try_into(&res.body);
400 match body {
401 Ok(body) => {
402 let cached_data = CachedEntry::new(res.status_code, headers, body);
403 if let Err(e) = self.store.save_entry(key, cached_data).await {
404 tracing::error!(error = ?e, "cache failed");
405 }
406 }
407 Err(e) => tracing::error!(error = ?e, "cache failed"),
408 }
409 }
410 return;
411 };
412 let CachedEntry {
413 status,
414 headers,
415 body,
416 } = cache;
417 if let Some(status) = status {
418 res.status_code(status);
419 }
420 *res.headers_mut() = headers;
421 *res.body_mut() = body.into();
422 ctrl.skip_rest();
423 }
424}
425
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use bytes::Bytes;
    use salvo_core::http::HeaderMap;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    use super::*;

    /// Returns a body that embeds the current time, so a fresh response always differs
    /// from a replayed (cached) one.
    #[handler]
    async fn cached() -> String {
        format!(
            "Hello World, my birth time is {}",
            OffsetDateTime::now_utc()
        )
    }

    /// A GET is served from the cache while the entry is alive. Once the TTL has elapsed,
    /// a GET to the same URL is regenerated.
    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(1))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let content0 = res.take_string().await.unwrap();

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // Wait past the TTL. The request must be a GET: the default skipper bypasses the
        // cache for other methods, which would make this assertion pass trivially.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content2 = res.take_string().await.unwrap();

        assert_ne!(content0, content2);
    }

    /// With the default skipper, POST requests are never answered from the cache.
    #[tokio::test]
    async fn test_cache_bypasses_post_requests() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(60))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res = TestClient::post("http://127.0.0.1:5801/post-path")
            .send(&service)
            .await;
        let content0 = res.take_string().await.unwrap();

        // Guarantee the handler's timestamp advances between the two requests.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let mut res = TestClient::post("http://127.0.0.1:5801/post-path")
            .send(&service)
            .await;
        let content1 = res.take_string().await.unwrap();

        assert_ne!(content0, content1);
    }

    #[test]
    fn test_request_issuer_new() {
        let issuer = RequestIssuer::new();
        assert!(issuer.use_scheme);
        assert!(issuer.use_authority);
        assert!(issuer.use_path);
        assert!(issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_default() {
        let issuer = RequestIssuer::default();
        assert!(issuer.use_scheme);
        assert!(issuer.use_authority);
        assert!(issuer.use_path);
        assert!(issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_use_scheme() {
        let issuer = RequestIssuer::new().use_scheme(false);
        assert!(!issuer.use_scheme);
        assert!(issuer.use_authority);
    }

    #[test]
    fn test_request_issuer_use_authority() {
        let issuer = RequestIssuer::new().use_authority(false);
        assert!(issuer.use_scheme);
        assert!(!issuer.use_authority);
    }

    #[test]
    fn test_request_issuer_use_path() {
        let issuer = RequestIssuer::new().use_path(false);
        assert!(!issuer.use_path);
    }

    #[test]
    fn test_request_issuer_use_query() {
        let issuer = RequestIssuer::new().use_query(false);
        assert!(!issuer.use_query);
    }

    #[test]
    fn test_request_issuer_use_method() {
        let issuer = RequestIssuer::new().use_method(false);
        assert!(!issuer.use_method);
    }

    #[test]
    fn test_request_issuer_chain() {
        let issuer = RequestIssuer::new()
            .use_scheme(false)
            .use_authority(false)
            .use_path(true)
            .use_query(false)
            .use_method(true);
        assert!(!issuer.use_scheme);
        assert!(!issuer.use_authority);
        assert!(issuer.use_path);
        assert!(!issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_debug() {
        let issuer = RequestIssuer::new();
        let debug_str = format!("{issuer:?}");
        assert!(debug_str.contains("RequestIssuer"));
        assert!(debug_str.contains("use_scheme"));
    }

    #[test]
    fn test_request_issuer_clone() {
        let issuer = RequestIssuer::new().use_scheme(false);
        let cloned = issuer.clone();
        assert_eq!(issuer.use_scheme, cloned.use_scheme);
        assert_eq!(issuer.use_authority, cloned.use_authority);
    }

    #[test]
    fn test_cached_body_none() {
        let body = CachedBody::None;
        assert_eq!(body, CachedBody::None);
    }

    #[test]
    fn test_cached_body_once() {
        let bytes = Bytes::from("test data");
        let body = CachedBody::Once(bytes.clone());
        assert_eq!(body, CachedBody::Once(bytes));
    }

    #[test]
    fn test_cached_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        chunks.push_back(Bytes::from("chunk2"));
        let body = CachedBody::Chunks(chunks.clone());
        assert_eq!(body, CachedBody::Chunks(chunks));
    }

    #[test]
    fn test_cached_body_try_from_res_body_none() {
        let res_body = ResBody::None;
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::None);
    }

    #[test]
    fn test_cached_body_try_from_res_body_once() {
        let bytes = Bytes::from("test");
        let res_body = ResBody::Once(bytes.clone());
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::Once(bytes));
    }

    #[test]
    fn test_cached_body_try_from_res_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        chunks.push_back(Bytes::from("chunk2"));
        let res_body = ResBody::Chunks(chunks.clone());
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::Chunks(chunks));
    }

    #[test]
    fn test_cached_body_into_res_body_none() {
        let cb = CachedBody::None;
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::None));
    }

    #[test]
    fn test_cached_body_into_res_body_once() {
        let bytes = Bytes::from("test");
        let cb = CachedBody::Once(bytes.clone());
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::Once(b) if b == bytes));
    }

    #[test]
    fn test_cached_body_into_res_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        let cb = CachedBody::Chunks(chunks);
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::Chunks(_)));
    }

    #[test]
    fn test_cached_body_debug() {
        let body = CachedBody::None;
        let debug_str = format!("{body:?}");
        assert!(debug_str.contains("None"));

        let body = CachedBody::Once(Bytes::from("test"));
        let debug_str = format!("{body:?}");
        assert!(debug_str.contains("Once"));
    }

    #[test]
    fn test_cached_body_clone() {
        let body = CachedBody::Once(Bytes::from("test"));
        let cloned = body.clone();
        assert_eq!(body, cloned);
    }

    #[test]
    fn test_cached_entry_new() {
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), CachedBody::None);
        assert_eq!(entry.status, Some(StatusCode::OK));
        assert!(entry.headers.is_empty());
        assert_eq!(entry.body, CachedBody::None);
    }

    #[test]
    fn test_cached_entry_status() {
        let entry = CachedEntry::new(
            Some(StatusCode::NOT_FOUND),
            HeaderMap::new(),
            CachedBody::None,
        );
        assert_eq!(entry.status(), Some(StatusCode::NOT_FOUND));
    }

    #[test]
    fn test_cached_entry_status_none() {
        let entry = CachedEntry::new(None, HeaderMap::new(), CachedBody::None);
        assert_eq!(entry.status(), None);
    }

    #[test]
    fn test_cached_entry_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("Content-Type", "application/json".parse().unwrap());
        let entry = CachedEntry::new(Some(StatusCode::OK), headers.clone(), CachedBody::None);
        assert_eq!(entry.headers().len(), 1);
        assert!(entry.headers().contains_key("Content-Type"));
    }

    #[test]
    fn test_cached_entry_body() {
        let body = CachedBody::Once(Bytes::from("test body"));
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), body.clone());
        assert_eq!(entry.body(), &body);
    }

    #[test]
    fn test_cached_entry_debug() {
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), CachedBody::None);
        let debug_str = format!("{entry:?}");
        assert!(debug_str.contains("CachedEntry"));
        assert!(debug_str.contains("status"));
    }

    #[test]
    fn test_cached_entry_clone() {
        let entry = CachedEntry::new(
            Some(StatusCode::OK),
            HeaderMap::new(),
            CachedBody::Once(Bytes::from("test")),
        );
        let cloned = entry.clone();
        assert_eq!(entry.status, cloned.status);
        assert_eq!(entry.body, cloned.body);
    }

    #[test]
    fn test_cache_new() {
        let cache = Cache::new(MokaStore::<String>::new(100), RequestIssuer::default());
        assert!(format!("{cache:?}").contains("Cache"));
    }

    #[test]
    fn test_cache_debug() {
        let cache = Cache::new(MokaStore::<String>::new(100), RequestIssuer::default());
        let debug_str = format!("{cache:?}");
        assert!(debug_str.contains("Cache"));
        assert!(debug_str.contains("store"));
        assert!(debug_str.contains("issuer"));
    }

    #[tokio::test]
    async fn test_cache_same_path_same_content() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(60))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res1 = TestClient::get("http://127.0.0.1:5801/same-path")
            .send(&service)
            .await;
        let content1 = res1.take_string().await.unwrap();

        let mut res2 = TestClient::get("http://127.0.0.1:5801/same-path")
            .send(&service)
            .await;
        let content2 = res2.take_string().await.unwrap();

        assert_eq!(content1, content2);
    }
}