surf_middleware_cache/
lib.rs

1//! A caching middleware for Surf that follows HTTP caching rules.
2//! By default it uses [`cacache`](https://github.com/zkat/cacache-rs) as the backend cache manager.
3//!
4//! ## Example
5//!
6//! ```no_run
7//! use surf_middleware_cache::{managers::CACacheManager, Cache, CacheMode};
8//!
9//! #[async_std::main]
10//! async fn main() -> surf::Result<()> {
11//!     let req = surf::get("https://developer.mozilla.org/en-US/docs/Web/HTTP/Caching");
12//!     surf::client()
13//!         .with(Cache {
14//!             mode: CacheMode::Default,
15//!             cache_manager: CACacheManager::default(),
16//!         })
17//!         .send(req)
18//!         .await?;
19//!     Ok(())
20//! }
21//! ```
22#![forbid(unsafe_code, future_incompatible)]
23#![deny(
24    missing_docs,
25    missing_debug_implementations,
26    missing_copy_implementations,
27    nonstandard_style,
28    unused_qualifications,
29    rustdoc::missing_doc_code_examples
30)]
31use std::{str::FromStr, time::SystemTime};
32
33use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
34use http_types::{
35    headers::{HeaderValue, CACHE_CONTROL},
36    Method,
37};
38use surf::{
39    middleware::{Middleware, Next},
40    Client, Request, Response,
41};
42
43/// Backend cache managers, cacache is the default.
44pub mod managers;
45
46type Result<T> = std::result::Result<T, http_types::Error>;
47
48/// A trait providing methods for storing, reading, and removing cache records.
49#[surf::utils::async_trait]
50pub trait CacheManager {
51    /// Attempts to pull a cached reponse and related policy from cache.
52    async fn get(&self, req: &Request) -> Result<Option<(Response, CachePolicy)>>;
53    /// Attempts to cache a response and related policy.
54    async fn put(&self, req: &Request, res: &mut Response, policy: CachePolicy)
55        -> Result<Response>;
56    /// Attempts to remove a record from cache.
57    async fn delete(&self, req: &Request) -> Result<()>;
58}
59
/// Controls how the [`Cache`] middleware consults and updates the HTTP cache.
///
/// Modeled on the [make-fetch-happen cache options](https://github.com/npm/make-fetch-happen#--optscache)
/// and chosen when the [`Cache`] struct is built. Only `GET` and `HEAD` requests are ever
/// answered from the cache; every other request goes to the network.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
    /// Inspects the HTTP cache on the way to the network.
    /// A fresh stored response is returned directly. A stale one triggers a request
    /// (made conditional when the stored response allows it), and a normal request is made
    /// when nothing is stored. The HTTP cache is then updated with the response.
    /// If revalidation fails (for example on a 5xx response, or when the network is
    /// unreachable), the stale response may be returned instead, with a `Warning` header
    /// describing the failure.
    Default,
    /// Behaves as if there is no HTTP cache at all: requests always go to the network and
    /// responses are never stored.
    NoStore,
    /// Behaves as if there is no HTTP cache on the way to the network.
    /// It always makes a normal request, then updates the HTTP cache with the response.
    Reload,
    /// Creates a conditional request if there is a response in the HTTP cache
    /// and a normal request otherwise. It then updates the HTTP cache with the response.
    NoCache,
    /// Uses any response in the HTTP cache matching the request, regardless of staleness,
    /// adding a `112 Disconnected operation` warning to it. If nothing is stored,
    /// it makes a normal request and updates the HTTP cache with the response.
    ForceCache,
    /// Uses any response in the HTTP cache matching the request, regardless of staleness,
    /// adding a `112 Disconnected operation` warning to it. If nothing is stored, no request
    /// is made; a synthesized `504 Gateway Timeout` response is returned instead.
    OnlyIfCached,
}
91
92/// Caches requests according to http spec
93#[derive(Debug, Clone)]
94pub struct Cache<T: CacheManager> {
95    /// Determines the manager behavior
96    pub mode: CacheMode,
97    /// Manager instance that implements the CacheManager trait
98    pub cache_manager: T,
99}
100
101impl<T: CacheManager> Cache<T> {
102    /// Called by the Surf middleware handle method when a request is made.
103    pub async fn run(&self, mut req: Request, client: Client, next: Next<'_>) -> Result<Response> {
104        let is_cacheable = (req.method() == Method::Get || req.method() == Method::Head)
105            && self.mode != CacheMode::NoStore
106            && self.mode != CacheMode::Reload;
107
108        if !is_cacheable {
109            return self.remote_fetch(req, client, next).await;
110        }
111
112        if let Some(store) = self.cache_manager.get(&req).await? {
113            let (mut res, policy) = store;
114            if let Some(warning_code) = get_warning_code(&res) {
115                // https://tools.ietf.org/html/rfc7234#section-4.3.4
116                //
117                // If a stored response is selected for update, the cache MUST:
118                //
119                // * delete any Warning header fields in the stored response with
120                //   warn-code 1xx (see Section 5.5);
121                //
122                // * retain any Warning header fields in the stored response with
123                //   warn-code 2xx;
124                //
125                #[allow(clippy::manual_range_contains)]
126                if warning_code >= 100 && warning_code < 200 {
127                    res.remove_header("Warning");
128                }
129            }
130
131            match self.mode {
132                CacheMode::Default => Ok(self
133                    .conditional_fetch(req, res, policy, client, next)
134                    .await?),
135                CacheMode::NoCache => {
136                    req.insert_header(CACHE_CONTROL.as_str(), "no-cache");
137                    Ok(self
138                        .conditional_fetch(req, res, policy, client, next)
139                        .await?)
140                }
141                CacheMode::ForceCache | CacheMode::OnlyIfCached => {
142                    //   112 Disconnected operation
143                    // SHOULD be included if the cache is intentionally disconnected from
144                    // the rest of the network for a period of time.
145                    // (https://tools.ietf.org/html/rfc2616#section-14.46)
146                    add_warning(&mut res, req.url(), 112, "Disconnected operation");
147                    Ok(res)
148                }
149                _ => Ok(self.remote_fetch(req, client, next).await?),
150            }
151        } else {
152            match self.mode {
153                CacheMode::OnlyIfCached => {
154                    // ENOTCACHED
155                    let err_res = http_types::Response::new(http_types::StatusCode::GatewayTimeout);
156                    Ok(err_res.into())
157                }
158                _ => Ok(self.remote_fetch(req, client, next).await?),
159            }
160        }
161    }
162
163    async fn conditional_fetch(
164        &self,
165        mut req: Request,
166        mut cached_res: Response,
167        mut policy: CachePolicy,
168        client: Client,
169        next: Next<'_>,
170    ) -> Result<Response> {
171        let before_req = policy.before_request(&get_request_parts(&req)?, SystemTime::now());
172        match before_req {
173            BeforeRequest::Fresh(parts) => {
174                update_response_headers(parts, &mut cached_res)?;
175                return Ok(cached_res);
176            }
177            BeforeRequest::Stale {
178                request: parts,
179                matches,
180            } => {
181                if matches {
182                    update_request_headers(parts, &mut req)?;
183                }
184            }
185        }
186        let copied_req = req.clone();
187        match self.remote_fetch(req, client, next).await {
188            Ok(cond_res) => {
189                if cond_res.status().is_server_error() && must_revalidate(&cached_res) {
190                    //   111 Revalidation failed
191                    //   MUST be included if a cache returns a stale response
192                    //   because an attempt to revalidate the response failed,
193                    //   due to an inability to reach the server.
194                    // (https://tools.ietf.org/html/rfc2616#section-14.46)
195                    add_warning(
196                        &mut cached_res,
197                        copied_req.url(),
198                        111,
199                        "Revalidation failed",
200                    );
201                    Ok(cached_res)
202                } else if cond_res.status() == http_types::StatusCode::NotModified {
203                    let mut res = http_types::Response::new(cond_res.status());
204                    for (key, value) in cond_res.iter() {
205                        res.append_header(key, value.clone().as_str());
206                    }
207                    res.set_body(cached_res.body_string().await?);
208                    let mut converted = Response::from(res);
209                    let after_res = policy.after_response(
210                        &get_request_parts(&copied_req)?,
211                        &get_response_parts(&cond_res)?,
212                        SystemTime::now(),
213                    );
214                    match after_res {
215                        AfterResponse::Modified(new_policy, parts) => {
216                            policy = new_policy;
217                            update_response_headers(parts, &mut converted)?;
218                        }
219                        AfterResponse::NotModified(new_policy, parts) => {
220                            policy = new_policy;
221                            update_response_headers(parts, &mut converted)?;
222                        }
223                    }
224                    let res = self
225                        .cache_manager
226                        .put(&copied_req, &mut converted, policy)
227                        .await?;
228                    Ok(res)
229                } else {
230                    Ok(cached_res)
231                }
232            }
233            Err(e) => {
234                if must_revalidate(&cached_res) {
235                    Err(e)
236                } else {
237                    //   111 Revalidation failed
238                    //   MUST be included if a cache returns a stale response
239                    //   because an attempt to revalidate the response failed,
240                    //   due to an inability to reach the server.
241                    // (https://tools.ietf.org/html/rfc2616#section-14.46)
242                    add_warning(
243                        &mut cached_res,
244                        copied_req.url(),
245                        111,
246                        "Revalidation failed",
247                    );
248                    //   199 Miscellaneous warning
249                    //   The warning text MAY include arbitrary information to
250                    //   be presented to a human user, or logged. A system
251                    //   receiving this warning MUST NOT take any automated
252                    //   action, besides presenting the warning to the user.
253                    // (https://tools.ietf.org/html/rfc2616#section-14.46)
254                    add_warning(
255                        &mut cached_res,
256                        copied_req.url(),
257                        199,
258                        format!("Miscellaneous Warning {}", e).as_str(),
259                    );
260                    Ok(cached_res)
261                }
262            }
263        }
264    }
265
266    async fn remote_fetch(&self, req: Request, client: Client, next: Next<'_>) -> Result<Response> {
267        let copied_req = req.clone();
268        let mut res = next.run(req, client).await?;
269        let is_method_get_head =
270            copied_req.method() == Method::Get || copied_req.method() == Method::Head;
271        let policy = CachePolicy::new(&get_request_parts(&copied_req)?, &get_response_parts(&res)?);
272        let is_cacheable = self.mode != CacheMode::NoStore
273            && is_method_get_head
274            && res.status() == http_types::StatusCode::Ok
275            && policy.is_storable();
276        if is_cacheable {
277            Ok(self
278                .cache_manager
279                .put(&copied_req, &mut res, policy)
280                .await?)
281        } else if !is_method_get_head {
282            self.cache_manager.delete(&copied_req).await?;
283            Ok(res)
284        } else {
285            Ok(res)
286        }
287    }
288}
289
290fn must_revalidate(res: &Response) -> bool {
291    if let Some(val) = res.header(CACHE_CONTROL.as_str()) {
292        val.as_str().to_lowercase().contains("must-revalidate")
293    } else {
294        false
295    }
296}
297
298fn get_warning_code(res: &Response) -> Option<usize> {
299    res.header("Warning").and_then(|hdr| {
300        hdr.as_str()
301            .chars()
302            .take(3)
303            .collect::<String>()
304            .parse()
305            .ok()
306    })
307}
308
309fn update_request_headers(parts: http::request::Parts, req: &mut Request) -> Result<()> {
310    for header in parts.headers.iter() {
311        req.set_header(
312            header.0.as_str(),
313            http_types::headers::HeaderValue::from_str(header.1.to_str()?)?,
314        );
315    }
316    Ok(())
317}
318
319fn update_response_headers(parts: http::response::Parts, res: &mut Response) -> Result<()> {
320    for header in parts.headers.iter() {
321        res.insert_header(
322            header.0.as_str(),
323            http_types::headers::HeaderValue::from_str(header.1.to_str()?)?,
324        );
325    }
326    Ok(())
327}
328
329// Convert the surf::Response for CachePolicy to use
330fn get_response_parts(res: &Response) -> Result<http::response::Parts> {
331    let mut headers = http::HeaderMap::new();
332    for header in res.iter() {
333        headers.insert(
334            http::header::HeaderName::from_str(header.0.as_str())?,
335            http::HeaderValue::from_str(header.1.as_str())?,
336        );
337    }
338    let status = http::StatusCode::from_str(res.status().to_string().as_ref())?;
339    let mut converted = http::response::Response::new(());
340    converted.headers_mut().clone_from(&headers);
341    converted.status_mut().clone_from(&status);
342    let parts = converted.into_parts();
343    Ok(parts.0)
344}
345
346// Convert the surf::Request for CachePolicy to use
347fn get_request_parts(req: &Request) -> Result<http::request::Parts> {
348    let mut headers = http::HeaderMap::new();
349    for header in req.iter() {
350        headers.insert(
351            http::header::HeaderName::from_str(header.0.as_str())?,
352            http::HeaderValue::from_str(header.1.as_str())?,
353        );
354    }
355    let uri = http::Uri::from_str(req.url().as_str())?;
356    let method = http::Method::from_str(req.method().as_ref())?;
357    let mut converted = http::request::Request::new(());
358    converted.headers_mut().clone_from(&headers);
359    converted.uri_mut().clone_from(&uri);
360    converted.method_mut().clone_from(&method);
361    let parts = converted.into_parts();
362    Ok(parts.0)
363}
364
365fn add_warning(res: &mut Response, uri: &surf::http::Url, code: usize, message: &str) {
366    //   Warning    = "Warning" ":" 1#warning-value
367    // warning-value = warn-code SP warn-agent SP warn-text [SP warn-date]
368    // warn-code  = 3DIGIT
369    // warn-agent = ( host [ ":" port ] ) | pseudonym
370    //                 ; the name or pseudonym of the server adding
371    //                 ; the Warning header, for use in debugging
372    // warn-text  = quoted-string
373    // warn-date  = <"> HTTP-date <">
374    // (https://tools.ietf.org/html/rfc2616#section-14.46)
375    //
376    let val = HeaderValue::from_str(
377        format!(
378            "{} {} {:?} \"{}\"",
379            code,
380            uri.host().expect("Invalid URL"),
381            message,
382            httpdate::fmt_http_date(SystemTime::now())
383        )
384        .as_str(),
385    )
386    .expect("Failed to generate warning string");
387    res.append_header("Warning", val);
388}
389
390#[surf::utils::async_trait]
391impl<T: CacheManager + 'static + Send + Sync> Middleware for Cache<T> {
392    async fn handle(&self, req: Request, client: Client, next: Next<'_>) -> Result<Response> {
393        let res = self.run(req, client, next).await?;
394        Ok(res)
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use http_types::{Response, StatusCode};
    use surf::Result;

    #[test]
    fn can_get_warning_code() -> Result<()> {
        let url = surf::http::Url::from_str("https://example.com")?;
        let mut res = surf::Response::from(Response::new(StatusCode::Ok));
        add_warning(&mut res, &url, 111, "Revalidation failed");
        let code = get_warning_code(&res).expect("a Warning header was just added");
        assert_eq!(code, 111);
        Ok(())
    }

    #[test]
    fn can_check_revalidate() {
        let mut res = Response::new(StatusCode::Ok);
        res.append_header("Cache-Control", "max-age=1733992, must-revalidate");
        assert!(
            must_revalidate(&res.into()),
            "the must-revalidate directive should be detected"
        );
    }

    #[test]
    fn revalidate_not_required_without_directive() {
        let mut res = Response::new(StatusCode::Ok);
        res.append_header("Cache-Control", "max-age=60");
        assert!(
            !must_revalidate(&res.into()),
            "max-age alone must not require revalidation"
        );

        let bare = Response::new(StatusCode::Ok);
        assert!(
            !must_revalidate(&bare.into()),
            "a response without Cache-Control must not require revalidation"
        );
    }
}