1#![forbid(unsafe_code, future_incompatible)]
23#![deny(
24 missing_docs,
25 missing_debug_implementations,
26 missing_copy_implementations,
27 nonstandard_style,
28 unused_qualifications,
29 rustdoc::missing_doc_code_examples
30)]
31use std::{str::FromStr, time::SystemTime};
32
33use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
34use http_types::{
35 headers::{HeaderValue, CACHE_CONTROL},
36 Method,
37};
38use surf::{
39 middleware::{Middleware, Next},
40 Client, Request, Response,
41};
42
43pub mod managers;
45
46type Result<T> = std::result::Result<T, http_types::Error>;
47
48#[surf::utils::async_trait]
50pub trait CacheManager {
51 async fn get(&self, req: &Request) -> Result<Option<(Response, CachePolicy)>>;
53 async fn put(&self, req: &Request, res: &mut Response, policy: CachePolicy)
55 -> Result<Response>;
56 async fn delete(&self, req: &Request) -> Result<()>;
58}
59
/// Selects how [`Cache`] combines the local store with the network.
///
/// Only `GET` and `HEAD` requests ever consult the store; every other method is
/// always forwarded to the network.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
    /// Use a stored response when the cache policy reports it fresh; otherwise
    /// revalidate it with the origin. On a miss, fetch from the network and store
    /// the response if it is storable.
    Default,
    /// Never consult the store and never write to it; every request goes to the
    /// network.
    NoStore,
    /// Skip the store lookup and go to the network, but still store the response
    /// when it is storable.
    Reload,
    /// Like [`CacheMode::Default`], except that `Cache-Control: no-cache` is added
    /// to the request before a stored response is evaluated.
    NoCache,
    /// Return a stored response whenever one exists, without checking freshness
    /// (a `112` warning is attached). On a miss, fetch from the network.
    ForceCache,
    /// Return a stored response whenever one exists, without checking freshness
    /// (a `112` warning is attached). On a miss, answer `504 Gateway Timeout`
    /// without touching the network.
    OnlyIfCached,
}
91
92#[derive(Debug, Clone)]
94pub struct Cache<T: CacheManager> {
95 pub mode: CacheMode,
97 pub cache_manager: T,
99}
100
101impl<T: CacheManager> Cache<T> {
102 pub async fn run(&self, mut req: Request, client: Client, next: Next<'_>) -> Result<Response> {
104 let is_cacheable = (req.method() == Method::Get || req.method() == Method::Head)
105 && self.mode != CacheMode::NoStore
106 && self.mode != CacheMode::Reload;
107
108 if !is_cacheable {
109 return self.remote_fetch(req, client, next).await;
110 }
111
112 if let Some(store) = self.cache_manager.get(&req).await? {
113 let (mut res, policy) = store;
114 if let Some(warning_code) = get_warning_code(&res) {
115 #[allow(clippy::manual_range_contains)]
126 if warning_code >= 100 && warning_code < 200 {
127 res.remove_header("Warning");
128 }
129 }
130
131 match self.mode {
132 CacheMode::Default => Ok(self
133 .conditional_fetch(req, res, policy, client, next)
134 .await?),
135 CacheMode::NoCache => {
136 req.insert_header(CACHE_CONTROL.as_str(), "no-cache");
137 Ok(self
138 .conditional_fetch(req, res, policy, client, next)
139 .await?)
140 }
141 CacheMode::ForceCache | CacheMode::OnlyIfCached => {
142 add_warning(&mut res, req.url(), 112, "Disconnected operation");
147 Ok(res)
148 }
149 _ => Ok(self.remote_fetch(req, client, next).await?),
150 }
151 } else {
152 match self.mode {
153 CacheMode::OnlyIfCached => {
154 let err_res = http_types::Response::new(http_types::StatusCode::GatewayTimeout);
156 Ok(err_res.into())
157 }
158 _ => Ok(self.remote_fetch(req, client, next).await?),
159 }
160 }
161 }
162
163 async fn conditional_fetch(
164 &self,
165 mut req: Request,
166 mut cached_res: Response,
167 mut policy: CachePolicy,
168 client: Client,
169 next: Next<'_>,
170 ) -> Result<Response> {
171 let before_req = policy.before_request(&get_request_parts(&req)?, SystemTime::now());
172 match before_req {
173 BeforeRequest::Fresh(parts) => {
174 update_response_headers(parts, &mut cached_res)?;
175 return Ok(cached_res);
176 }
177 BeforeRequest::Stale {
178 request: parts,
179 matches,
180 } => {
181 if matches {
182 update_request_headers(parts, &mut req)?;
183 }
184 }
185 }
186 let copied_req = req.clone();
187 match self.remote_fetch(req, client, next).await {
188 Ok(cond_res) => {
189 if cond_res.status().is_server_error() && must_revalidate(&cached_res) {
190 add_warning(
196 &mut cached_res,
197 copied_req.url(),
198 111,
199 "Revalidation failed",
200 );
201 Ok(cached_res)
202 } else if cond_res.status() == http_types::StatusCode::NotModified {
203 let mut res = http_types::Response::new(cond_res.status());
204 for (key, value) in cond_res.iter() {
205 res.append_header(key, value.clone().as_str());
206 }
207 res.set_body(cached_res.body_string().await?);
208 let mut converted = Response::from(res);
209 let after_res = policy.after_response(
210 &get_request_parts(&copied_req)?,
211 &get_response_parts(&cond_res)?,
212 SystemTime::now(),
213 );
214 match after_res {
215 AfterResponse::Modified(new_policy, parts) => {
216 policy = new_policy;
217 update_response_headers(parts, &mut converted)?;
218 }
219 AfterResponse::NotModified(new_policy, parts) => {
220 policy = new_policy;
221 update_response_headers(parts, &mut converted)?;
222 }
223 }
224 let res = self
225 .cache_manager
226 .put(&copied_req, &mut converted, policy)
227 .await?;
228 Ok(res)
229 } else {
230 Ok(cached_res)
231 }
232 }
233 Err(e) => {
234 if must_revalidate(&cached_res) {
235 Err(e)
236 } else {
237 add_warning(
243 &mut cached_res,
244 copied_req.url(),
245 111,
246 "Revalidation failed",
247 );
248 add_warning(
255 &mut cached_res,
256 copied_req.url(),
257 199,
258 format!("Miscellaneous Warning {}", e).as_str(),
259 );
260 Ok(cached_res)
261 }
262 }
263 }
264 }
265
266 async fn remote_fetch(&self, req: Request, client: Client, next: Next<'_>) -> Result<Response> {
267 let copied_req = req.clone();
268 let mut res = next.run(req, client).await?;
269 let is_method_get_head =
270 copied_req.method() == Method::Get || copied_req.method() == Method::Head;
271 let policy = CachePolicy::new(&get_request_parts(&copied_req)?, &get_response_parts(&res)?);
272 let is_cacheable = self.mode != CacheMode::NoStore
273 && is_method_get_head
274 && res.status() == http_types::StatusCode::Ok
275 && policy.is_storable();
276 if is_cacheable {
277 Ok(self
278 .cache_manager
279 .put(&copied_req, &mut res, policy)
280 .await?)
281 } else if !is_method_get_head {
282 self.cache_manager.delete(&copied_req).await?;
283 Ok(res)
284 } else {
285 Ok(res)
286 }
287 }
288}
289
290fn must_revalidate(res: &Response) -> bool {
291 if let Some(val) = res.header(CACHE_CONTROL.as_str()) {
292 val.as_str().to_lowercase().contains("must-revalidate")
293 } else {
294 false
295 }
296}
297
298fn get_warning_code(res: &Response) -> Option<usize> {
299 res.header("Warning").and_then(|hdr| {
300 hdr.as_str()
301 .chars()
302 .take(3)
303 .collect::<String>()
304 .parse()
305 .ok()
306 })
307}
308
309fn update_request_headers(parts: http::request::Parts, req: &mut Request) -> Result<()> {
310 for header in parts.headers.iter() {
311 req.set_header(
312 header.0.as_str(),
313 http_types::headers::HeaderValue::from_str(header.1.to_str()?)?,
314 );
315 }
316 Ok(())
317}
318
319fn update_response_headers(parts: http::response::Parts, res: &mut Response) -> Result<()> {
320 for header in parts.headers.iter() {
321 res.insert_header(
322 header.0.as_str(),
323 http_types::headers::HeaderValue::from_str(header.1.to_str()?)?,
324 );
325 }
326 Ok(())
327}
328
329fn get_response_parts(res: &Response) -> Result<http::response::Parts> {
331 let mut headers = http::HeaderMap::new();
332 for header in res.iter() {
333 headers.insert(
334 http::header::HeaderName::from_str(header.0.as_str())?,
335 http::HeaderValue::from_str(header.1.as_str())?,
336 );
337 }
338 let status = http::StatusCode::from_str(res.status().to_string().as_ref())?;
339 let mut converted = http::response::Response::new(());
340 converted.headers_mut().clone_from(&headers);
341 converted.status_mut().clone_from(&status);
342 let parts = converted.into_parts();
343 Ok(parts.0)
344}
345
346fn get_request_parts(req: &Request) -> Result<http::request::Parts> {
348 let mut headers = http::HeaderMap::new();
349 for header in req.iter() {
350 headers.insert(
351 http::header::HeaderName::from_str(header.0.as_str())?,
352 http::HeaderValue::from_str(header.1.as_str())?,
353 );
354 }
355 let uri = http::Uri::from_str(req.url().as_str())?;
356 let method = http::Method::from_str(req.method().as_ref())?;
357 let mut converted = http::request::Request::new(());
358 converted.headers_mut().clone_from(&headers);
359 converted.uri_mut().clone_from(&uri);
360 converted.method_mut().clone_from(&method);
361 let parts = converted.into_parts();
362 Ok(parts.0)
363}
364
365fn add_warning(res: &mut Response, uri: &surf::http::Url, code: usize, message: &str) {
366 let val = HeaderValue::from_str(
377 format!(
378 "{} {} {:?} \"{}\"",
379 code,
380 uri.host().expect("Invalid URL"),
381 message,
382 httpdate::fmt_http_date(SystemTime::now())
383 )
384 .as_str(),
385 )
386 .expect("Failed to generate warning string");
387 res.append_header("Warning", val);
388}
389
390#[surf::utils::async_trait]
391impl<T: CacheManager + 'static + Send + Sync> Middleware for Cache<T> {
392 async fn handle(&self, req: Request, client: Client, next: Next<'_>) -> Result<Response> {
393 let res = self.run(req, client, next).await?;
394 Ok(res)
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use http_types::{Response, StatusCode};
    use surf::Result;

    /// A warning written by `add_warning` is read back by `get_warning_code`.
    #[async_std::test]
    async fn can_get_warning_code() -> Result<()> {
        let target = surf::http::Url::from_str("https://example.com")?;
        let mut response: surf::Response = Response::new(StatusCode::Ok).into();
        add_warning(&mut response, &target, 111, "Revalidation failed");
        assert_eq!(get_warning_code(&response), Some(111));
        Ok(())
    }

    /// `must-revalidate` is detected next to other `Cache-Control` directives.
    #[async_std::test]
    async fn can_check_revalidate() {
        let mut response = Response::new(StatusCode::Ok);
        response.append_header("Cache-Control", "max-age=1733992, must-revalidate");
        let check = must_revalidate(&surf::Response::from(response));
        assert!(check, "{}", true)
    }
}