1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
15#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
16#![cfg_attr(docsrs, feature(doc_cfg))]
17
18use std::borrow::Borrow;
19use std::collections::VecDeque;
20use std::error::Error as StdError;
21use std::fmt::{self, Debug, Formatter};
22use std::hash::Hash;
23
24use bytes::Bytes;
25use salvo_core::handler::Skipper;
26use salvo_core::http::{HeaderMap, ResBody, StatusCode};
27use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
28
29mod skipper;
30pub use skipper::MethodSkipper;
31
32#[macro_use]
33mod cfg;
34
35cfg_feature! {
36 #![feature = "moka-store"]
37
38 pub mod moka_store;
39 pub use moka_store::{MokaStore};
40}
41
42pub trait CacheIssuer: Send + Sync + 'static {
44 type Key: Hash + Eq + Send + Sync + 'static;
46 fn issue(
48 &self,
49 req: &mut Request,
50 depot: &Depot,
51 ) -> impl Future<Output = Option<Self::Key>> + Send;
52}
53impl<F, K> CacheIssuer for F
54where
55 F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
56 K: Hash + Eq + Send + Sync + 'static,
57{
58 type Key = K;
59 async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
60 (self)(req, depot)
61 }
62}
63
/// A [`CacheIssuer`] that builds a `String` key from selected parts of the request.
///
/// Each flag controls whether the corresponding component contributes to the key.
/// The type holds only boolean configuration, so it is cheap to copy and can be
/// compared or hashed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RequestIssuer {
    /// Include the URI scheme.
    use_scheme: bool,
    /// Include the URI authority.
    use_authority: bool,
    /// Include the URI path.
    use_path: bool,
    /// Include the URI query string.
    use_query: bool,
    /// Include the HTTP method.
    use_method: bool,
}
73impl Default for RequestIssuer {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78impl RequestIssuer {
79 #[must_use]
81 pub fn new() -> Self {
82 Self {
83 use_scheme: true,
84 use_authority: true,
85 use_path: true,
86 use_query: true,
87 use_method: true,
88 }
89 }
90 #[must_use]
92 pub fn use_scheme(mut self, value: bool) -> Self {
93 self.use_scheme = value;
94 self
95 }
96 #[must_use]
98 pub fn use_authority(mut self, value: bool) -> Self {
99 self.use_authority = value;
100 self
101 }
102 #[must_use]
104 pub fn use_path(mut self, value: bool) -> Self {
105 self.use_path = value;
106 self
107 }
108 #[must_use]
110 pub fn use_query(mut self, value: bool) -> Self {
111 self.use_query = value;
112 self
113 }
114 #[must_use]
116 pub fn use_method(mut self, value: bool) -> Self {
117 self.use_method = value;
118 self
119 }
120}
121
122impl CacheIssuer for RequestIssuer {
123 type Key = String;
124 async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
125 let mut key = String::new();
126 if self.use_scheme
127 && let Some(scheme) = req.uri().scheme_str()
128 {
129 key.push_str(scheme);
130 key.push_str("://");
131 }
132 if self.use_authority
133 && let Some(authority) = req.uri().authority()
134 {
135 key.push_str(authority.as_str());
136 }
137 if self.use_path {
138 key.push_str(req.uri().path());
139 }
140 if self.use_query
141 && let Some(query) = req.uri().query()
142 {
143 key.push('?');
144 key.push_str(query);
145 }
146 if self.use_method {
147 key.push('|');
148 key.push_str(req.method().as_str());
149 }
150 Some(key)
151 }
152}
153
154pub trait CacheStore: Send + Sync + 'static {
156 type Error: StdError + Sync + Send + 'static;
158 type Key: Hash + Eq + Send + Clone + 'static;
160 fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
162 where
163 Self::Key: Borrow<Q>,
164 Q: Hash + Eq + Sync;
165 fn save_entry(
167 &self,
168 key: Self::Key,
169 data: CachedEntry,
170 ) -> impl Future<Output = Result<(), Self::Error>> + Send;
171}
172
173#[derive(Clone, Debug, PartialEq)]
178#[non_exhaustive]
179pub enum CachedBody {
180 None,
182 Once(Bytes),
184 Chunks(VecDeque<Bytes>),
186}
187impl TryFrom<&ResBody> for CachedBody {
188 type Error = Error;
189 fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
190 match body {
191 ResBody::None => Ok(Self::None),
192 ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
193 ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
194 _ => Err(Error::other("unsupported body type")),
195 }
196 }
197}
198impl From<CachedBody> for ResBody {
199 fn from(body: CachedBody) -> Self {
200 match body {
201 CachedBody::None => Self::None,
202 CachedBody::Once(bytes) => Self::Once(bytes),
203 CachedBody::Chunks(chunks) => Self::Chunks(chunks),
204 }
205 }
206}
207
208#[derive(Clone, Debug)]
210#[non_exhaustive]
211pub struct CachedEntry {
212 pub status: Option<StatusCode>,
214 pub headers: HeaderMap,
216 pub body: CachedBody,
220}
221impl CachedEntry {
222 pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
224 Self {
225 status,
226 headers,
227 body,
228 }
229 }
230
231 pub fn status(&self) -> Option<StatusCode> {
233 self.status
234 }
235
236 pub fn headers(&self) -> &HeaderMap {
238 &self.headers
239 }
240
241 pub fn body(&self) -> &CachedBody {
245 &self.body
246 }
247}
248
249#[non_exhaustive]
268pub struct Cache<S, I> {
269 pub store: S,
271 pub issuer: I,
273 pub skipper: Box<dyn Skipper>,
275}
276impl<S, I> Debug for Cache<S, I>
277where
278 S: Debug,
279 I: Debug,
280{
281 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
282 f.debug_struct("Cache")
283 .field("store", &self.store)
284 .field("issuer", &self.issuer)
285 .finish()
286 }
287}
288
289impl<S, I> Cache<S, I> {
290 #[inline]
292 #[must_use]
293 pub fn new(store: S, issuer: I) -> Self {
294 let skipper = MethodSkipper::new().skip_all().skip_get(false);
295 Self {
296 store,
297 issuer,
298 skipper: Box::new(skipper),
299 }
300 }
301 #[inline]
303 #[must_use]
304 pub fn skipper(mut self, skipper: impl Skipper) -> Self {
305 self.skipper = Box::new(skipper);
306 self
307 }
308}
309
310#[async_trait]
311impl<S, I> Handler for Cache<S, I>
312where
313 S: CacheStore<Key = I::Key>,
314 I: CacheIssuer,
315{
316 async fn handle(
317 &self,
318 req: &mut Request,
319 depot: &mut Depot,
320 res: &mut Response,
321 ctrl: &mut FlowCtrl,
322 ) {
323 if self.skipper.skipped(req, depot) {
324 return;
325 }
326 let Some(key) = self.issuer.issue(req, depot).await else {
327 return;
328 };
329 let Some(cache) = self.store.load_entry(&key).await else {
330 ctrl.call_next(req, depot, res).await;
331 if !res.body.is_stream() && !res.body.is_error() {
332 let headers = res.headers().clone();
333 let body = TryInto::<CachedBody>::try_into(&res.body);
334 match body {
335 Ok(body) => {
336 let cached_data = CachedEntry::new(res.status_code, headers, body);
337 if let Err(e) = self.store.save_entry(key, cached_data).await {
338 tracing::error!(error = ?e, "cache failed");
339 }
340 }
341 Err(e) => tracing::error!(error = ?e, "cache failed"),
342 }
343 }
344 return;
345 };
346 let CachedEntry {
347 status,
348 headers,
349 body,
350 } = cache;
351 if let Some(status) = status {
352 res.status_code(status);
353 }
354 *res.headers_mut() = headers;
355 *res.body_mut() = body.into();
356 ctrl.skip_rest();
357 }
358}
359
// The test relies on `MokaStore`, which is only compiled with the `moka-store` feature.
#[cfg(all(test, feature = "moka-store"))]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    use super::*;

    #[handler]
    async fn cached() -> String {
        format!(
            "Hello World, my birth time is {}",
            OffsetDateTime::now_utc()
        )
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(2))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        // The first GET runs the handler and populates the cache.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content0 = res.take_string().await.unwrap();

        // A second GET within the TTL is served from the cache.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // A POST never receives the cached GET response.
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        let content_post = res.take_string().await.unwrap();
        assert_ne!(content0, content_post);

        // Once the TTL has elapsed, a GET must run the handler again. Before this
        // change, the test sent a POST here, which never consulted the cached entry,
        // so expiry was not actually tested.
        tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        let content2 = res.take_string().await.unwrap();
        assert_ne!(content0, content2);
    }
}
410}