1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
79#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
80#![cfg_attr(docsrs, feature(doc_cfg))]
81
82use std::borrow::Borrow;
83use std::collections::VecDeque;
84use std::error::Error as StdError;
85use std::fmt::{self, Debug, Formatter};
86use std::hash::Hash;
87
88use bytes::Bytes;
89use salvo_core::handler::Skipper;
90use salvo_core::http::{HeaderMap, ResBody, StatusCode};
91use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
92
93mod skipper;
94pub use skipper::MethodSkipper;
95
96#[macro_use]
97mod cfg;
98
99cfg_feature! {
100 #![feature = "moka-store"]
101
102 pub mod moka_store;
103 pub use moka_store::{MokaStore};
104}
105
106pub trait CacheIssuer: Send + Sync + 'static {
108 type Key: Hash + Eq + Send + Sync + 'static;
110 fn issue(
112 &self,
113 req: &mut Request,
114 depot: &Depot,
115 ) -> impl Future<Output = Option<Self::Key>> + Send;
116}
117impl<F, K> CacheIssuer for F
118where
119 F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
120 K: Hash + Eq + Send + Sync + 'static,
121{
122 type Key = K;
123 async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
124 (self)(req, depot)
125 }
126}
127
/// Issuer that builds a `String` cache key from selected parts of the request.
///
/// Enabled parts are appended in the order scheme, authority, path, query,
/// method. [`RequestIssuer::new`] enables all of them.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    /// Include the URI scheme, followed by `://`.
    use_scheme: bool,
    /// Include the URI authority component.
    use_authority: bool,
    /// Include the URI path.
    use_path: bool,
    /// Include the query string, prefixed with `?`.
    use_query: bool,
    /// Include the HTTP method, prefixed with `|`.
    use_method: bool,
}
137impl Default for RequestIssuer {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142impl RequestIssuer {
143 #[must_use]
145 pub fn new() -> Self {
146 Self {
147 use_scheme: true,
148 use_authority: true,
149 use_path: true,
150 use_query: true,
151 use_method: true,
152 }
153 }
154 #[must_use]
156 pub fn use_scheme(mut self, value: bool) -> Self {
157 self.use_scheme = value;
158 self
159 }
160 #[must_use]
162 pub fn use_authority(mut self, value: bool) -> Self {
163 self.use_authority = value;
164 self
165 }
166 #[must_use]
168 pub fn use_path(mut self, value: bool) -> Self {
169 self.use_path = value;
170 self
171 }
172 #[must_use]
174 pub fn use_query(mut self, value: bool) -> Self {
175 self.use_query = value;
176 self
177 }
178 #[must_use]
180 pub fn use_method(mut self, value: bool) -> Self {
181 self.use_method = value;
182 self
183 }
184}
185
186impl CacheIssuer for RequestIssuer {
187 type Key = String;
188 async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
189 let mut key = String::new();
190 if self.use_scheme
191 && let Some(scheme) = req.uri().scheme_str()
192 {
193 key.push_str(scheme);
194 key.push_str("://");
195 }
196 if self.use_authority
197 && let Some(authority) = req.uri().authority()
198 {
199 key.push_str(authority.as_str());
200 }
201 if self.use_path {
202 key.push_str(req.uri().path());
203 }
204 if self.use_query
205 && let Some(query) = req.uri().query()
206 {
207 key.push('?');
208 key.push_str(query);
209 }
210 if self.use_method {
211 key.push('|');
212 key.push_str(req.method().as_str());
213 }
214 Some(key)
215 }
216}
217
218pub trait CacheStore: Send + Sync + 'static {
220 type Error: StdError + Sync + Send + 'static;
222 type Key: Hash + Eq + Send + Clone + 'static;
224 fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
226 where
227 Self::Key: Borrow<Q>,
228 Q: Hash + Eq + Sync;
229 fn save_entry(
231 &self,
232 key: Self::Key,
233 data: CachedEntry,
234 ) -> impl Future<Output = Result<(), Self::Error>> + Send;
235}
236
/// Response body in a form that can be stored in a cache.
///
/// Only bodies already held in memory are representable; the
/// `TryFrom<&ResBody>` conversion rejects every other `ResBody` variant.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum CachedBody {
    /// No body.
    None,
    /// A single contiguous buffer.
    Once(Bytes),
    /// A sequence of buffers, kept in order.
    Chunks(VecDeque<Bytes>),
}
251impl TryFrom<&ResBody> for CachedBody {
252 type Error = Error;
253 fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
254 match body {
255 ResBody::None => Ok(Self::None),
256 ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
257 ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
258 _ => Err(Error::other("unsupported body type")),
259 }
260 }
261}
262impl From<CachedBody> for ResBody {
263 fn from(body: CachedBody) -> Self {
264 match body {
265 CachedBody::None => Self::None,
266 CachedBody::Once(bytes) => Self::Once(bytes),
267 CachedBody::Chunks(chunks) => Self::Chunks(chunks),
268 }
269 }
270}
271
272#[derive(Clone, Debug)]
274#[non_exhaustive]
275pub struct CachedEntry {
276 pub status: Option<StatusCode>,
278 pub headers: HeaderMap,
280 pub body: CachedBody,
284}
285impl CachedEntry {
286 pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
288 Self {
289 status,
290 headers,
291 body,
292 }
293 }
294
295 pub fn status(&self) -> Option<StatusCode> {
297 self.status
298 }
299
300 pub fn headers(&self) -> &HeaderMap {
302 &self.headers
303 }
304
305 pub fn body(&self) -> &CachedBody {
309 &self.body
310 }
311}
312
/// Response-caching middleware.
///
/// On a cache hit the stored status, headers and body are written to the
/// response and the remaining handlers are skipped. On a miss the rest of the
/// chain runs, and the resulting response is saved if its body can be cached.
#[non_exhaustive]
pub struct Cache<S, I> {
    /// Store holding cached responses.
    pub store: S,
    /// Issuer deriving the cache key for each request.
    pub issuer: I,
    /// Decides which requests bypass the cache entirely.
    pub skipper: Box<dyn Skipper>,
}
340impl<S, I> Debug for Cache<S, I>
341where
342 S: Debug,
343 I: Debug,
344{
345 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
346 f.debug_struct("Cache")
347 .field("store", &self.store)
348 .field("issuer", &self.issuer)
349 .finish()
350 }
351}
352
353impl<S, I> Cache<S, I> {
354 #[inline]
356 #[must_use]
357 pub fn new(store: S, issuer: I) -> Self {
358 let skipper = MethodSkipper::new().skip_all().skip_get(false);
359 Self {
360 store,
361 issuer,
362 skipper: Box::new(skipper),
363 }
364 }
365 #[inline]
367 #[must_use]
368 pub fn skipper(mut self, skipper: impl Skipper) -> Self {
369 self.skipper = Box::new(skipper);
370 self
371 }
372}
373
374#[async_trait]
375impl<S, I> Handler for Cache<S, I>
376where
377 S: CacheStore<Key = I::Key>,
378 I: CacheIssuer,
379{
380 async fn handle(
381 &self,
382 req: &mut Request,
383 depot: &mut Depot,
384 res: &mut Response,
385 ctrl: &mut FlowCtrl,
386 ) {
387 if self.skipper.skipped(req, depot) {
388 return;
389 }
390 let Some(key) = self.issuer.issue(req, depot).await else {
391 return;
392 };
393 let Some(cache) = self.store.load_entry(&key).await else {
394 ctrl.call_next(req, depot, res).await;
395 if !res.body.is_stream() && !res.body.is_error() {
396 let headers = res.headers().clone();
397 let body = TryInto::<CachedBody>::try_into(&res.body);
398 match body {
399 Ok(body) => {
400 let cached_data = CachedEntry::new(res.status_code, headers, body);
401 if let Err(e) = self.store.save_entry(key, cached_data).await {
402 tracing::error!(error = ?e, "cache failed");
403 }
404 }
405 Err(e) => tracing::error!(error = ?e, "cache failed"),
406 }
407 }
408 return;
409 };
410 let CachedEntry {
411 status,
412 headers,
413 body,
414 } = cache;
415 if let Some(status) = status {
416 res.status_code(status);
417 }
418 *res.headers_mut() = headers;
419 *res.body_mut() = body.into();
420 ctrl.skip_rest();
421 }
422}
423
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use bytes::Bytes;
    use salvo_core::http::HeaderMap;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    use super::*;

    /// Returns a body containing the current time, so two fresh renders differ.
    #[handler]
    async fn cached() -> String {
        format!(
            "Hello World, my birth time is {}",
            OffsetDateTime::now_utc()
        )
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(1))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let content0 = res.take_string().await.unwrap();

        // A second GET within the TTL must be served from the cache.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // POST is skipped by the default skipper, so it always reaches the handler.
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        let content_post = res.take_string().await.unwrap();
        assert_ne!(content0, content_post);

        // After the TTL elapses, the cached GET entry must be regenerated. This must be
        // a GET: a POST would bypass the cache and pass regardless of expiry.
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        let content2 = res.take_string().await.unwrap();

        assert_ne!(content0, content2);
    }

    #[test]
    fn test_request_issuer_new() {
        let issuer = RequestIssuer::new();
        assert!(issuer.use_scheme);
        assert!(issuer.use_authority);
        assert!(issuer.use_path);
        assert!(issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_default() {
        let issuer = RequestIssuer::default();
        assert!(issuer.use_scheme);
        assert!(issuer.use_authority);
        assert!(issuer.use_path);
        assert!(issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_use_scheme() {
        let issuer = RequestIssuer::new().use_scheme(false);
        assert!(!issuer.use_scheme);
        assert!(issuer.use_authority);
    }

    #[test]
    fn test_request_issuer_use_authority() {
        let issuer = RequestIssuer::new().use_authority(false);
        assert!(issuer.use_scheme);
        assert!(!issuer.use_authority);
    }

    #[test]
    fn test_request_issuer_use_path() {
        let issuer = RequestIssuer::new().use_path(false);
        assert!(!issuer.use_path);
    }

    #[test]
    fn test_request_issuer_use_query() {
        let issuer = RequestIssuer::new().use_query(false);
        assert!(!issuer.use_query);
    }

    #[test]
    fn test_request_issuer_use_method() {
        let issuer = RequestIssuer::new().use_method(false);
        assert!(!issuer.use_method);
    }

    #[test]
    fn test_request_issuer_chain() {
        let issuer = RequestIssuer::new()
            .use_scheme(false)
            .use_authority(false)
            .use_path(true)
            .use_query(false)
            .use_method(true);
        assert!(!issuer.use_scheme);
        assert!(!issuer.use_authority);
        assert!(issuer.use_path);
        assert!(!issuer.use_query);
        assert!(issuer.use_method);
    }

    #[test]
    fn test_request_issuer_debug() {
        let issuer = RequestIssuer::new();
        let debug_str = format!("{:?}", issuer);
        assert!(debug_str.contains("RequestIssuer"));
        assert!(debug_str.contains("use_scheme"));
    }

    #[test]
    fn test_request_issuer_clone() {
        let issuer = RequestIssuer::new().use_scheme(false);
        let cloned = issuer.clone();
        assert_eq!(issuer.use_scheme, cloned.use_scheme);
        assert_eq!(issuer.use_authority, cloned.use_authority);
    }

    #[test]
    fn test_cached_body_none() {
        let body = CachedBody::None;
        assert_eq!(body, CachedBody::None);
    }

    #[test]
    fn test_cached_body_once() {
        let bytes = Bytes::from("test data");
        let body = CachedBody::Once(bytes.clone());
        assert_eq!(body, CachedBody::Once(bytes));
    }

    #[test]
    fn test_cached_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        chunks.push_back(Bytes::from("chunk2"));
        let body = CachedBody::Chunks(chunks.clone());
        assert_eq!(body, CachedBody::Chunks(chunks));
    }

    #[test]
    fn test_cached_body_try_from_res_body_none() {
        let res_body = ResBody::None;
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::None);
    }

    #[test]
    fn test_cached_body_try_from_res_body_once() {
        let bytes = Bytes::from("test");
        let res_body = ResBody::Once(bytes.clone());
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::Once(bytes));
    }

    #[test]
    fn test_cached_body_try_from_res_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        chunks.push_back(Bytes::from("chunk2"));
        let res_body = ResBody::Chunks(chunks.clone());
        let result: Result<CachedBody, _> = (&res_body).try_into();
        assert_eq!(result.unwrap(), CachedBody::Chunks(chunks));
    }

    #[test]
    fn test_cached_body_into_res_body_none() {
        let cb = CachedBody::None;
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::None));
    }

    #[test]
    fn test_cached_body_into_res_body_once() {
        let bytes = Bytes::from("test");
        let cb = CachedBody::Once(bytes.clone());
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::Once(b) if b == bytes));
    }

    #[test]
    fn test_cached_body_into_res_body_chunks() {
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from("chunk1"));
        let cb = CachedBody::Chunks(chunks);
        let res_body: ResBody = cb.into();
        assert!(matches!(res_body, ResBody::Chunks(_)));
    }

    #[test]
    fn test_cached_body_debug() {
        let body = CachedBody::None;
        let debug_str = format!("{:?}", body);
        assert!(debug_str.contains("None"));

        let body = CachedBody::Once(Bytes::from("test"));
        let debug_str = format!("{:?}", body);
        assert!(debug_str.contains("Once"));
    }

    #[test]
    fn test_cached_body_clone() {
        let body = CachedBody::Once(Bytes::from("test"));
        let cloned = body.clone();
        assert_eq!(body, cloned);
    }

    #[test]
    fn test_cached_entry_new() {
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), CachedBody::None);
        assert_eq!(entry.status, Some(StatusCode::OK));
        assert!(entry.headers.is_empty());
        assert_eq!(entry.body, CachedBody::None);
    }

    #[test]
    fn test_cached_entry_status() {
        let entry = CachedEntry::new(
            Some(StatusCode::NOT_FOUND),
            HeaderMap::new(),
            CachedBody::None,
        );
        assert_eq!(entry.status(), Some(StatusCode::NOT_FOUND));
    }

    #[test]
    fn test_cached_entry_status_none() {
        let entry = CachedEntry::new(None, HeaderMap::new(), CachedBody::None);
        assert_eq!(entry.status(), None);
    }

    #[test]
    fn test_cached_entry_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("Content-Type", "application/json".parse().unwrap());
        let entry = CachedEntry::new(Some(StatusCode::OK), headers.clone(), CachedBody::None);
        assert_eq!(entry.headers().len(), 1);
        assert!(entry.headers().contains_key("Content-Type"));
    }

    #[test]
    fn test_cached_entry_body() {
        let body = CachedBody::Once(Bytes::from("test body"));
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), body.clone());
        assert_eq!(entry.body(), &body);
    }

    #[test]
    fn test_cached_entry_debug() {
        let entry = CachedEntry::new(Some(StatusCode::OK), HeaderMap::new(), CachedBody::None);
        let debug_str = format!("{:?}", entry);
        assert!(debug_str.contains("CachedEntry"));
        assert!(debug_str.contains("status"));
    }

    #[test]
    fn test_cached_entry_clone() {
        let entry = CachedEntry::new(
            Some(StatusCode::OK),
            HeaderMap::new(),
            CachedBody::Once(Bytes::from("test")),
        );
        let cloned = entry.clone();
        assert_eq!(entry.status, cloned.status);
        assert_eq!(entry.body, cloned.body);
    }

    #[test]
    fn test_cache_new() {
        let cache = Cache::new(MokaStore::<String>::new(100), RequestIssuer::default());
        assert!(format!("{:?}", cache).contains("Cache"));
    }

    #[test]
    fn test_cache_debug() {
        let cache = Cache::new(MokaStore::<String>::new(100), RequestIssuer::default());
        let debug_str = format!("{:?}", cache);
        assert!(debug_str.contains("Cache"));
        assert!(debug_str.contains("store"));
        assert!(debug_str.contains("issuer"));
    }

    #[tokio::test]
    async fn test_cache_same_path_same_content() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(60))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res1 = TestClient::get("http://127.0.0.1:5801/same-path")
            .send(&service)
            .await;
        let content1 = res1.take_string().await.unwrap();

        let mut res2 = TestClient::get("http://127.0.0.1:5801/same-path")
            .send(&service)
            .await;
        let content2 = res2.take_string().await.unwrap();

        assert_eq!(content1, content2);
    }
}