#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(unreachable_pub)]
#![forbid(unsafe_code)]
#![warn(missing_docs)]
use std::borrow::Borrow;
use std::collections::VecDeque;
use std::error::Error as StdError;
use std::hash::Hash;
use bytes::Bytes;
use salvo_core::handler::Skipper;
use salvo_core::http::{HeaderMap, ResBody, StatusCode};
use salvo_core::{async_trait, Depot, Error, FlowCtrl, Handler, Request, Response};
mod skipper;
pub use skipper::MethodSkipper;
#[macro_use]
mod cfg;
cfg_feature! {
#![feature = "moka-store"]
pub mod moka_store;
pub use moka_store::{MokaStore};
}
/// Derives the cache key for a request.
///
/// [`Cache`] calls [`CacheIssuer::issue`] for every request its skipper lets
/// through; returning `None` makes the handler return without looking up or
/// storing any cache entry.
#[async_trait]
pub trait CacheIssuer: Send + Sync + 'static {
    /// Key type produced by this issuer.
    type Key: Hash + Eq + Send + Sync + 'static;
    /// Returns the cache key for `req`, or `None` to leave the request uncached.
    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key>;
}
/// Lets any synchronous closure `Fn(&mut Request, &Depot) -> Option<K>` act as a
/// [`CacheIssuer`]; whatever the closure returns is used as the key verbatim.
#[async_trait]
impl<F, K> CacheIssuer for F
where
    F: Fn(&mut Request, &Depot) -> Option<K> + Send + Sync + 'static,
    K: Hash + Eq + Send + Sync + 'static,
{
    type Key = K;
    async fn issue(&self, req: &mut Request, depot: &Depot) -> Option<Self::Key> {
        let issuer_fn = self;
        issuer_fn(req, depot)
    }
}
/// [`CacheIssuer`] that builds a `String` key from parts of the request URI
/// and, optionally, the HTTP method.
///
/// Every part is enabled by default; toggle them with the builder methods.
#[derive(Clone, Debug)]
pub struct RequestIssuer {
    /// Include the URI scheme, followed by `://`, when the URI has one.
    use_scheme: bool,
    /// Include the URI authority when the URI has one.
    use_authority: bool,
    /// Include the URI path.
    use_path: bool,
    /// Include `?` plus the query string when the URI has one.
    use_query: bool,
    /// Append `|` plus the request method.
    use_method: bool,
}
impl Default for RequestIssuer {
fn default() -> Self {
Self::new()
}
}
impl RequestIssuer {
    /// Creates an issuer with every key part enabled: scheme, authority, path,
    /// query and method.
    pub fn new() -> Self {
        Self {
            use_scheme: true,
            use_authority: true,
            use_path: true,
            use_query: true,
            use_method: true,
        }
    }
    /// Sets whether the URI scheme is included in the key.
    #[must_use]
    pub fn use_scheme(mut self, value: bool) -> Self {
        self.use_scheme = value;
        self
    }
    /// Sets whether the URI authority is included in the key.
    #[must_use]
    pub fn use_authority(mut self, value: bool) -> Self {
        self.use_authority = value;
        self
    }
    /// Sets whether the URI path is included in the key.
    #[must_use]
    pub fn use_path(mut self, value: bool) -> Self {
        self.use_path = value;
        self
    }
    /// Sets whether the URI query string is included in the key.
    #[must_use]
    pub fn use_query(mut self, value: bool) -> Self {
        self.use_query = value;
        self
    }
    /// Sets whether the request method is included in the key.
    #[must_use]
    pub fn use_method(mut self, value: bool) -> Self {
        self.use_method = value;
        self
    }
}
/// Builds keys of the shape `scheme://authority/path?query|METHOD`, omitting
/// any part that is disabled or absent from the request URI.
#[async_trait]
impl CacheIssuer for RequestIssuer {
    type Key = String;
    async fn issue(&self, req: &mut Request, _depot: &Depot) -> Option<Self::Key> {
        let uri = req.uri();
        // Resolve the optional URI parts up front; a disabled part counts as absent.
        let scheme = if self.use_scheme { uri.scheme_str() } else { None };
        let authority = if self.use_authority { uri.authority() } else { None };
        let query = if self.use_query { uri.query() } else { None };

        let mut key = String::new();
        if let Some(scheme) = scheme {
            key.push_str(scheme);
            key.push_str("://");
        }
        if let Some(authority) = authority {
            key.push_str(authority.as_str());
        }
        if self.use_path {
            key.push_str(uri.path());
        }
        if let Some(query) = query {
            key.push('?');
            key.push_str(query);
        }
        if self.use_method {
            key.push('|');
            key.push_str(req.method().as_str());
        }
        Some(key)
    }
}
/// Storage backend that [`Cache`] reads entries from and writes entries to.
#[async_trait]
pub trait CacheStore: Send + Sync + 'static {
    /// Error returned when saving an entry fails.
    type Error: StdError + Sync + Send + 'static;
    /// Key type under which entries are stored.
    type Key: Hash + Eq + Send + Clone + 'static;
    /// Looks up the entry stored under `key`; `None` is treated as a cache miss.
    async fn load_entry<Q>(&self, key: &Q) -> Option<CachedEntry>
    where
        Self::Key: Borrow<Q>,
        Q: Hash + Eq + Sync;
    /// Stores `data` under `key`.
    async fn save_entry(&self, key: Self::Key, data: CachedEntry) -> Result<(), Self::Error>;
}
/// Owned copy of a response body that can be stored and later replayed.
///
/// Only in-memory body shapes are representable; other `ResBody` variants
/// fail the `TryFrom<&ResBody>` conversion.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum CachedBody {
    /// No body.
    None,
    /// A single buffer.
    Once(Bytes),
    /// A sequence of buffers, kept in order.
    Chunks(VecDeque<Bytes>),
}
impl TryFrom<&ResBody> for CachedBody {
type Error = Error;
fn try_from(body: &ResBody) -> Result<Self, Self::Error> {
match body {
ResBody::None => Ok(Self::None),
ResBody::Once(bytes) => Ok(Self::Once(bytes.to_owned())),
ResBody::Chunks(chunks) => Ok(Self::Chunks(chunks.to_owned())),
_ => Err(Error::other("unsupported body type")),
}
}
}
impl From<CachedBody> for ResBody {
fn from(body: CachedBody) -> Self {
match body {
CachedBody::None => Self::None,
CachedBody::Once(bytes) => Self::Once(bytes),
CachedBody::Chunks(chunks) => Self::Chunks(chunks),
}
}
}
/// A stored response: optional status code, headers and body.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CachedEntry {
    /// Status code of the response, if one was set; replayed only when `Some`.
    pub status: Option<StatusCode>,
    /// Response headers; replayed by replacing the outgoing header map.
    pub headers: HeaderMap,
    /// Response body.
    pub body: CachedBody,
}
impl CachedEntry {
    /// Creates an entry from its status code, headers and body.
    pub fn new(status: Option<StatusCode>, headers: HeaderMap, body: CachedBody) -> Self {
        Self { status, headers, body }
    }
    /// Returns the stored status code, if any.
    pub fn status(&self) -> Option<StatusCode> {
        self.status
    }
    /// Returns the stored headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }
    /// Returns the stored body.
    pub fn body(&self) -> &CachedBody {
        &self.body
    }
}
/// Middleware [`Handler`] that serves responses from a [`CacheStore`].
///
/// On a miss, the rest of the chain runs and the resulting response is stored
/// unless its body is a stream or an error. On a hit, the stored status,
/// headers and body are written to the response and the remaining handlers
/// are skipped.
#[non_exhaustive]
pub struct Cache<S, I> {
    /// Backend holding cached entries.
    pub store: S,
    /// Computes the cache key for each request.
    pub issuer: I,
    /// Decides which requests bypass the cache entirely.
    pub skipper: Box<dyn Skipper>,
}
impl<S, I> Cache<S, I> {
    /// Creates a cache handler that stores responses in `store`, keyed by `issuer`.
    ///
    /// The default skipper is `MethodSkipper::new().skip_all().skip_get(false)`,
    /// which is intended to let only `GET` requests reach the cache. Replace it
    /// with [`Cache::skipper`].
    #[inline]
    pub fn new(store: S, issuer: I) -> Self {
        let skipper = MethodSkipper::new().skip_all().skip_get(false);
        Cache {
            store,
            issuer,
            skipper: Box::new(skipper),
        }
    }
    /// Replaces the skipper that decides which requests bypass the cache.
    #[inline]
    #[must_use]
    pub fn skipper(mut self, skipper: impl Skipper) -> Self {
        self.skipper = Box::new(skipper);
        self
    }
}
#[async_trait]
impl<S, I> Handler for Cache<S, I>
where
S: CacheStore<Key = I::Key>,
I: CacheIssuer,
{
async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
if self.skipper.skipped(req, depot) {
return;
}
let key = match self.issuer.issue(req, depot).await {
Some(key) => key,
None => {
return;
}
};
let cache = match self.store.load_entry(&key).await {
Some(cache) => cache,
None => {
ctrl.call_next(req, depot, res).await;
if !res.body.is_stream() && !res.body.is_error() {
let headers = res.headers().clone();
let body = TryInto::<CachedBody>::try_into(&res.body);
match body {
Ok(body) => {
let cached_data = CachedEntry::new(res.status_code, headers, body);
if let Err(e) = self.store.save_entry(key, cached_data).await {
tracing::error!(error = ?e, "cache failed");
}
}
Err(e) => tracing::error!(error = ?e, "cache failed"),
}
}
return;
}
};
let CachedEntry { status, headers, body } = cache;
if let Some(status) = status {
res.status_code(status);
}
*res.headers_mut() = headers;
*res.body_mut() = body.into();
ctrl.skip_rest();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    use time::OffsetDateTime;

    /// Embeds the current time in the body, so an identical body on a repeat
    /// request shows it was served from the cache.
    #[handler]
    async fn cached() -> String {
        format!("Hello World, my birth time is {}", OffsetDateTime::now_utc())
    }

    #[tokio::test]
    async fn test_cache() {
        let cache = Cache::new(
            MokaStore::builder()
                .time_to_live(std::time::Duration::from_secs(5))
                .build(),
            RequestIssuer::default(),
        );
        let router = Router::new().hoop(cache).goal(cached);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content0 = res.take_string().await.unwrap();

        // Second GET within the TTL must be served from the cache.
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content1 = res.take_string().await.unwrap();
        assert_eq!(content0, content1);

        // Wait past the 5s TTL and repeat the same GET. The default skipper
        // skips POST, so a POST here would bypass the cache and could not
        // detect whether the entry actually expired.
        tokio::time::sleep(tokio::time::Duration::from_secs(6)).await;
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let content2 = res.take_string().await.unwrap();
        assert_ne!(content0, content2);
    }
}