tower_http_cache/backend/
mod.rs1#[cfg(feature = "memcached-backend")]
15pub mod memcached;
16pub mod memory;
17pub mod multi_tier;
18#[cfg(feature = "redis-backend")]
19pub mod redis;
20
21use async_trait::async_trait;
22use bytes::Bytes;
23use http::{HeaderName, HeaderValue, Response, StatusCode, Version};
24use std::time::{Duration, SystemTime};
25
26use crate::error::CacheError;
27use crate::layer::SyncBoxBody;
28
29#[derive(Debug, Clone)]
31#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
32pub struct CacheEntry {
33 #[cfg_attr(feature = "serde", serde(with = "status_code_serde"))]
34 pub status: StatusCode,
35 #[cfg_attr(feature = "serde", serde(with = "version_serde"))]
36 pub version: Version,
37 pub headers: Vec<(String, Vec<u8>)>,
38 #[cfg_attr(feature = "serde", serde(with = "bytes_serde"))]
39 pub body: Bytes,
40 pub tags: Option<Vec<String>>,
41}
42
/// Serde adapter that stores an [`http::StatusCode`] as its numeric `u16`.
#[cfg(feature = "serde")]
mod status_code_serde {
    use http::StatusCode;
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes the status as the `u16` returned by `StatusCode::as_u16`.
    pub fn serialize<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_u16(status.as_u16())
    }

    /// Reads a `u16` and converts it with `StatusCode::from_u16`.
    ///
    /// A value rejected by `from_u16` is reported as a custom
    /// deserialization error.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
    where
        D: Deserializer<'de>,
    {
        u16::deserialize(deserializer)
            .and_then(|raw| StatusCode::from_u16(raw).map_err(serde::de::Error::custom))
    }
}
64
/// Serde adapter that stores an [`http::Version`] as a one-byte tag.
///
/// Tag mapping: `0` = HTTP/0.9, `1` = HTTP/1.0, `2` = HTTP/1.1,
/// `3` = HTTP/2, `4` = HTTP/3, `5` = any other version.
#[cfg(feature = "serde")]
mod version_serde {
    use http::Version;
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes the version as a `u8` tag.
    ///
    /// The tag is explicitly typed as `u8` so that the written integer has
    /// the same type `deserialize` reads. An unsuffixed literal would fall
    /// back to `i32`, which formats that are not self-describing encode
    /// differently from a `u8`.
    pub fn serialize<S>(version: &Version, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let tag: u8 = match *version {
            Version::HTTP_09 => 0,
            Version::HTTP_10 => 1,
            Version::HTTP_11 => 2,
            Version::HTTP_2 => 3,
            Version::HTTP_3 => 4,
            // Versions this adapter has no dedicated tag for.
            _ => 5,
        };
        serializer.serialize_u8(tag)
    }

    /// Reads a `u8` tag and maps it back to a [`Version`].
    ///
    /// Any tag outside `0..=4` (including the `5` written for unrecognised
    /// versions) decodes as HTTP/1.1 rather than failing.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Version, D::Error>
    where
        D: Deserializer<'de>,
    {
        let tag = u8::deserialize(deserializer)?;
        Ok(match tag {
            0 => Version::HTTP_09,
            1 => Version::HTTP_10,
            2 => Version::HTTP_11,
            3 => Version::HTTP_2,
            4 => Version::HTTP_3,
            // Lenient fallback for unknown tags.
            _ => Version::HTTP_11,
        })
    }
}
100
/// Serde adapter that stores a [`bytes::Bytes`] buffer as a byte string.
#[cfg(feature = "serde")]
mod bytes_serde {
    use bytes::Bytes;
    use serde::de::{Error, SeqAccess, Visitor};
    use serde::{Deserializer, Serializer};
    use std::fmt;

    /// Writes the buffer with `Serializer::serialize_bytes`.
    pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_bytes(bytes)
    }

    /// Visitor accepting every shape a format may use for data written by
    /// `serialize_bytes`: borrowed bytes, an owned byte buffer, or a
    /// sequence of `u8` elements.
    struct BytesVisitor;

    impl<'de> Visitor<'de> for BytesVisitor {
        type Value = Bytes;

        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("a byte buffer or a sequence of bytes")
        }

        fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
        where
            E: Error,
        {
            Ok(Bytes::copy_from_slice(value))
        }

        fn visit_byte_buf<E>(self, value: Vec<u8>) -> Result<Self::Value, E>
        where
            E: Error,
        {
            Ok(Bytes::from(value))
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: SeqAccess<'de>,
        {
            // Cap the up-front reservation so a bogus size hint in the input
            // cannot force a huge allocation.
            let reserve = seq.size_hint().unwrap_or(0).min(4096);
            let mut buffer = Vec::with_capacity(reserve);
            while let Some(byte) = seq.next_element::<u8>()? {
                buffer.push(byte);
            }
            Ok(Bytes::from(buffer))
        }
    }

    /// Reads a byte buffer back into [`Bytes`].
    ///
    /// Requests a byte buffer from the deserializer (the counterpart of
    /// `serialize_bytes`) and still accepts a plain sequence of `u8`, so
    /// formats with a native byte-string type and formats that encode bytes
    /// as a sequence both round-trip.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_byte_buf(BytesVisitor)
    }
}
121
122impl CacheEntry {
123 pub fn new(
128 status: StatusCode,
129 version: Version,
130 headers: Vec<(String, Vec<u8>)>,
131 body: Bytes,
132 ) -> Self {
133 Self {
134 status,
135 version,
136 headers,
137 body,
138 tags: None,
139 }
140 }
141
142 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144 self.tags = Some(tags);
145 self
146 }
147
148 pub fn into_response(self) -> Response<SyncBoxBody> {
150 use http_body_util::BodyExt;
151
152 let full_body = http_body_util::Full::from(self.body);
153 let boxed_body = full_body
154 .map_err(|never| -> Box<dyn std::error::Error + Send + Sync> { match never {} })
155 .boxed();
156
157 let mut response = Response::new(SyncBoxBody::new(boxed_body));
158 *response.status_mut() = self.status;
159 *response.version_mut() = self.version;
160
161 let headers = response.headers_mut();
162 headers.clear();
163 for (name, value) in self.headers {
164 if let (Ok(name), Ok(value)) = (
165 HeaderName::from_bytes(name.as_bytes()),
166 HeaderValue::from_bytes(&value),
167 ) {
168 headers.append(name, value);
169 }
170 }
171
172 response
173 }
174}
175
176#[derive(Debug, Clone)]
177pub struct CacheRead {
178 pub entry: CacheEntry,
180 pub expires_at: Option<SystemTime>,
181 pub stale_until: Option<SystemTime>,
182}
183
184#[async_trait]
185pub trait CacheBackend: Send + Sync + Clone + 'static {
186 async fn get(&self, key: &str) -> Result<Option<CacheRead>, CacheError>;
191
192 async fn set(
194 &self,
195 key: String,
196 entry: CacheEntry,
197 ttl: Duration,
198 stale_for: Duration,
199 ) -> Result<(), CacheError>;
200
201 async fn invalidate(&self, key: &str) -> Result<(), CacheError>;
203
204 async fn get_keys_by_tag(&self, _tag: &str) -> Result<Vec<String>, CacheError> {
208 Ok(Vec::new())
209 }
210
211 async fn invalidate_by_tag(&self, tag: &str) -> Result<usize, CacheError> {
215 let keys = self.get_keys_by_tag(tag).await?;
216 let count = keys.len();
217 for key in keys {
218 let _ = self.invalidate(&key).await;
219 }
220 Ok(count)
221 }
222
223 async fn invalidate_by_tags(&self, tags: &[String]) -> Result<usize, CacheError> {
227 let mut total = 0;
228 for tag in tags {
229 total += self.invalidate_by_tag(tag).await?;
230 }
231 Ok(total)
232 }
233
234 async fn list_tags(&self) -> Result<Vec<String>, CacheError> {
238 Ok(Vec::new())
239 }
240}