1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
15#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
16#![cfg_attr(docsrs, feature(doc_cfg))]
17
18use std::borrow::Borrow;
19use std::collections::VecDeque;
20use std::error::Error as StdError;
21use std::hash::Hash;
22
23use bytes::Bytes;
24use salvo_core::handler::Skipper;
25use salvo_core::http::{HeaderMap, ResBody, StatusCode};
26use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
27
28mod skipper;
29pub use skipper::MethodSkipper;
30
31#[macro_use]
32mod cfg;
33
34cfg_feature! {
35 #![feature = "moka-store"]
36
37 pub mod moka_store;
38 pub use moka_store::{MokaStore};
39}
40
/// Produces the cache key for an incoming request.
///
/// When [`issue`](CacheIssuer::issue) returns `None`, the [`Cache`] handler returns early
/// without looking up or storing anything for that request.
pub trait CacheIssuer: Send + Sync + 'static {
    /// Key type used to look entries up in a [`CacheStore`].
    type Key: Hash + Eq + Send + Sync + 'static;
    /// Computes the cache key for `req`, or `None` to bypass the cache for this request.
    fn issue(
        &self,
        req: &mut Request,
        depot: &Depot,
    ) -> impl Future<Output = Option<Self::Key>> + Send;
}
52impl<F, K> CacheIssuer for F
53where
54 F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
55 K: Hash + Eq + Send + Sync + 'static,
56{
57 type Key = K;
58 async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
59 (self)(req, depot)
60 }
61}
62
/// Default [`CacheIssuer`] that builds the cache key from parts of the request.
///
/// Scheme, authority, path, query and method are all included by default; use the builder
/// methods to leave individual components out of the key.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    use_scheme: bool,
    use_authority: bool,
    use_path: bool,
    use_query: bool,
    use_method: bool,
}
impl Default for RequestIssuer {
    fn default() -> Self {
        Self::new()
    }
}
impl RequestIssuer {
    /// Creates an issuer with every key component enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            use_scheme: true,
            use_authority: true,
            use_path: true,
            use_query: true,
            use_method: true,
        }
    }
    /// Sets whether the URI scheme is part of the key.
    // `#[must_use]`: these builders consume `self`; dropping the result discards the setting.
    #[must_use]
    pub fn use_scheme(mut self, value: bool) -> Self {
        self.use_scheme = value;
        self
    }
    /// Sets whether the URI authority (host and port) is part of the key.
    #[must_use]
    pub fn use_authority(mut self, value: bool) -> Self {
        self.use_authority = value;
        self
    }
    /// Sets whether the URI path is part of the key.
    #[must_use]
    pub fn use_path(mut self, value: bool) -> Self {
        self.use_path = value;
        self
    }
    /// Sets whether the URI query string is part of the key.
    #[must_use]
    pub fn use_query(mut self, value: bool) -> Self {
        self.use_query = value;
        self
    }
    /// Sets whether the HTTP method is part of the key.
    #[must_use]
    pub fn use_method(mut self, value: bool) -> Self {
        self.use_method = value;
        self
    }
}
114
115impl CacheIssuer for RequestIssuer {
116 type Key = String;
117 async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
118 let mut key = String::new();
119 if self.use_scheme {
120 if let Some(scheme) = req.uri().scheme_str() {
121 key.push_str(scheme);
122 key.push_str("://");
123 }
124 }
125 if self.use_authority {
126 if let Some(authority) = req.uri().authority() {
127 key.push_str(authority.as_str());
128 }
129 }
130 if self.use_path {
131 key.push_str(req.uri().path());
132 }
133 if self.use_query {
134 if let Some(query) = req.uri().query() {
135 key.push('?');
136 key.push_str(query);
137 }
138 }
139 if self.use_method {
140 key.push('|');
141 key.push_str(req.method().as_str());
142 }
143 Some(key)
144 }
145}
146
/// Storage backend that persists and retrieves [`CachedEntry`] values.
pub trait CacheStore: Send + Sync + 'static {
    /// Error returned by [`save_entry`](CacheStore::save_entry).
    type Error: StdError + Sync + Send + 'static;
    /// Key type used to index stored entries.
    type Key: Hash + Eq + Send + Clone + 'static;
    /// Returns the entry stored under `key`, or `None` on a miss.
    ///
    /// `Self::Key: Borrow<Q>` allows lookups by a borrowed form of the key.
    /// NOTE(review): `Q` is implicitly `Sized` (no `?Sized`), so e.g. a `String` key cannot be
    /// looked up with a `&str`.
    fn load_entry<Q>(&self, key: &Q) -> impl Future<Output = Option<CachedEntry>> + Send
    where
        Self::Key: Borrow<Q>,
        Q: Hash + Eq + Sync;
    /// Stores `data` under `key`.
    ///
    /// The `Cache` handler only logs an error from this method; the response is still sent.
    fn save_entry(
        &self,
        key: Self::Key,
        data: CachedEntry,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
165
/// A response body captured in a form that can be stored and replayed.
///
/// Only in-memory bodies (`None`, `Once`, `Chunks`) can be converted from a `ResBody`; any
/// other body kind fails the `TryFrom<&ResBody>` conversion.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum CachedBody {
    /// No body.
    None,
    /// A single contiguous chunk of bytes.
    Once(Bytes),
    /// A sequence of byte chunks, kept in order.
    Chunks(VecDeque<Bytes>),
}
180impl TryFrom<&ResBody> for CachedBody {
181 type Error = Error;
182 fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
183 match body {
184 ResBody::None => Ok(Self::None),
185 ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
186 ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
187 _ => Err(Error::other("unsupported body type")),
188 }
189 }
190}
191impl From<CachedBody> for ResBody {
192 fn from(body: CachedBody) -> Self {
193 match body {
194 CachedBody::None => Self::None,
195 CachedBody::Once(bytes) => Self::Once(bytes),
196 CachedBody::Chunks(chunks) => Self::Chunks(chunks),
197 }
198 }
199}
200
/// A cached response: status, headers and body captured from a handler's output.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CachedEntry {
    /// Response status, if one had been set when the response was cached.
    pub status: Option<StatusCode>,
    /// Response headers; on a cache hit they replace the outgoing response's headers wholesale.
    pub headers: HeaderMap,
    /// Response body.
    pub body: CachedBody,
}
214impl CachedEntry {
215 pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
217 Self {
218 status,
219 headers,
220 body,
221 }
222 }
223
224 pub fn status(&self) -> Option<StatusCode> {
226 self.status
227 }
228
229 pub fn headers(&self) -> &HeaderMap {
231 &self.headers
232 }
233
234 pub fn body(&self) -> &CachedBody {
238 &self.body
239 }
240}
241
/// Response-caching middleware.
///
/// For each request not rejected by `skipper`, it asks `issuer` for a key. On a hit in
/// `store`, it replays the cached response and skips the rest of the chain. On a miss, it runs
/// the chain and tries to store the produced response.
#[non_exhaustive]
pub struct Cache<S, I> {
    /// Backend that stores and loads cached entries.
    pub store: S,
    /// Produces the cache key for each request.
    pub issuer: I,
    /// Requests for which this returns `true` bypass the cache entirely.
    pub skipper: Box<dyn Skipper>,
}
267
268impl<S, I> Cache<S, I> {
269 #[inline]
271 pub fn new(store: S, issuer: I) -> Self {
272 let skipper = MethodSkipper::new().skip_all().skip_get(false);
273 Cache {
274 store,
275 issuer,
276 skipper: Box::new(skipper),
277 }
278 }
279 #[inline]
281 pub fn skipper(mut self, skipper: impl Skipper) -> Self {
282 self.skipper = Box::new(skipper);
283 self
284 }
285}
286
287#[async_trait]
288impl<S, I> Handler for Cache<S, I>
289where
290 S: CacheStore<Key = I::Key>,
291 I: CacheIssuer,
292{
293 async fn handle(
294 &self,
295 req: &mut Request,
296 depot: &mut Depot,
297 res: &mut Response,
298 ctrl: &mut FlowCtrl,
299 ) {
300 if self.skipper.skipped(req, depot) {
301 return;
302 }
303 let key = match self.issuer.issue(req, depot).await {
304 Some(key) => key,
305 None => {
306 return;
307 }
308 };
309 let cache = match self.store.load_entry(&key).await {
310 Some(cache) => cache,
311 None => {
312 ctrl.call_next(req, depot, res).await;
313 if !res.body.is_stream() && !res.body.is_error() {
314 let headers = res.headers().clone();
315 let body = TryInto::<CachedBody>::try_into(&res.body);
316 match body {
317 Ok(body) => {
318 let cached_data = CachedEntry::new(res.status_code, headers, body);
319 if let Err(e) = self.store.save_entry(key, cached_data).await {
320 tracing::error!(error = ?e, "cache failed");
321 }
322 }
323 Err(e) => tracing::error!(error = ?e, "cache failed"),
324 }
325 }
326 return;
327 }
328 };
329 let CachedEntry {
330 status,
331 headers,
332 body,
333 } = cache;
334 if let Some(status) = status {
335 res.status_code(status);
336 }
337 *res.headers_mut() = headers;
338 *res.body_mut() = body.into();
339 ctrl.skip_rest();
340 }
341}
342
// Gated on the feature because the test relies on `MokaStore`, which only exists with
// `moka-store` enabled.
#[cfg(all(test, feature = "moka-store"))]
mod tests {
    use super::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    /// Returns a body that differs on every invocation, so cache hits are detectable.
    #[handler]
    async fn cached() -> String {
        format!(
            "Hello World, my birth time is {}",
            OffsetDateTime::now_utc()
        )
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(1))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        // First GET is a miss: the handler runs and its response is stored.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content0 = res.take_string().await.unwrap();

        // Second GET within the TTL is served from the cache.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // POST bypasses the cache under the default skipper, so the handler runs again.
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        let content_post = res.take_string().await.unwrap();
        assert_ne!(content0, content_post);

        // Once the TTL has elapsed, the GET entry has expired and is regenerated.
        // (Previously this step used POST, which never hits the cache, so expiry went untested.)
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content2 = res.take_string().await.unwrap();
        assert_ne!(content0, content2);
    }
}