tower_http_cache/
policy.rs

1use http::{HeaderMap, Method, StatusCode};
2use std::collections::HashSet;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::logging::MLLoggingConfig;
8use crate::streaming::StreamingPolicy;
9use crate::tags::TagPolicy;
10
11/// Type alias for the method predicate function
12type MethodPredicateFn = Arc<dyn Fn(&Method) -> bool + Send + Sync>;
13
14/// Type alias for tag extractor function
15type TagExtractorFn = Arc<dyn Fn(&Method, &http::Uri) -> Vec<String> + Send + Sync>;
16
17/// Runtime cache policy shared by the layer and backend.
18///
19/// Policies define how long responses stay in cache, which headers are
20/// persisted, which status codes are cacheable, and much more. Policies are
21/// cheap to clone and are immutable—the `with_*` builder helpers return new
22/// copies with the requested change.
23#[derive(Clone)]
24pub struct CachePolicy {
25    ttl: Duration,
26    negative_ttl: Duration,
27    stale_while_revalidate: Duration,
28    refresh_before: Duration,
29    max_body_size: Option<usize>,
30    min_body_size: Option<usize>,
31    cache_statuses: HashSet<u16>,
32    respect_cache_control: bool,
33    method_predicate: Option<MethodPredicateFn>,
34    header_allowlist: Option<HashSet<String>>,
35    allow_streaming_bodies: bool,
36    compression: CompressionConfig,
37    ml_logging: MLLoggingConfig,
38    tag_policy: TagPolicy,
39    tag_extractor: Option<TagExtractorFn>,
40    streaming_policy: StreamingPolicy,
41}
42
/// Strategy for compressing cached payloads.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CompressionStrategy {
    /// Leave cached payloads uncompressed (used by the default configuration).
    None,
    /// Gzip-compress cached payloads.
    Gzip,
}

/// Compression configuration attached to a [`CachePolicy`].
///
/// The default disables compression and sets `min_size` to 1024.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CompressionConfig {
    /// Compression algorithm to apply, if any.
    pub strategy: CompressionStrategy,
    /// Size threshold for compression. Presumably payloads smaller than this
    /// many bytes are stored uncompressed; enforced outside this module, so
    /// confirm against the backend.
    pub min_size: usize,
}

impl Default for CompressionConfig {
    fn default() -> Self {
        Self {
            strategy: CompressionStrategy::None,
            min_size: 1024,
        }
    }
}
65
66impl fmt::Debug for CachePolicy {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        f.debug_struct("CachePolicy")
69            .field("ttl", &self.ttl)
70            .field("negative_ttl", &self.negative_ttl)
71            .field("stale_while_revalidate", &self.stale_while_revalidate)
72            .field("refresh_before", &self.refresh_before)
73            .field("max_body_size", &self.max_body_size)
74            .field("min_body_size", &self.min_body_size)
75            .field("cache_statuses", &self.cache_statuses)
76            .field("respect_cache_control", &self.respect_cache_control)
77            .field(
78                "header_allowlist",
79                &self
80                    .header_allowlist
81                    .as_ref()
82                    .map(|set| set.iter().collect::<Vec<_>>()),
83            )
84            .field("allow_streaming_bodies", &self.allow_streaming_bodies)
85            .field("compression", &self.compression)
86            .finish()
87    }
88}
89
90impl CachePolicy {
91    /// Builds a new policy with explicit TTL settings and cacheable statuses.
92    pub fn new(
93        ttl: Duration,
94        negative_ttl: Duration,
95        statuses: impl IntoIterator<Item = u16>,
96    ) -> Self {
97        Self {
98            ttl,
99            negative_ttl,
100            stale_while_revalidate: Duration::from_secs(0),
101            refresh_before: Duration::from_secs(0),
102            max_body_size: None,
103            min_body_size: None,
104            cache_statuses: statuses.into_iter().collect(),
105            respect_cache_control: true,
106            method_predicate: None,
107            header_allowlist: None,
108            allow_streaming_bodies: false,
109            compression: CompressionConfig::default(),
110            ml_logging: MLLoggingConfig::default(),
111            tag_policy: TagPolicy::default(),
112            tag_extractor: None,
113            streaming_policy: StreamingPolicy::default(),
114        }
115    }
116
117    /// Returns the TTL for the given HTTP status code if it should be cached.
118    pub fn ttl_for(&self, status: StatusCode) -> Option<Duration> {
119        if self.cache_statuses.contains(&status.as_u16()) {
120            Some(self.ttl)
121        } else if status.is_client_error() && !self.negative_ttl.is_zero() {
122            Some(self.negative_ttl)
123        } else {
124            None
125        }
126    }
127
128    /// Determines whether the request method is cacheable.
129    pub fn should_cache_method(&self, method: &Method) -> bool {
130        if let Some(predicate) = &self.method_predicate {
131            predicate(method)
132        } else {
133            matches!(method, &Method::GET | &Method::HEAD)
134        }
135    }
136
137    /// Returns whether `Cache-Control`/`Pragma` headers on requests are honored.
138    pub fn respect_cache_control(&self) -> bool {
139        self.respect_cache_control
140    }
141
142    pub fn max_body_size(&self) -> Option<usize> {
143        self.max_body_size
144    }
145
146    pub fn min_body_size(&self) -> Option<usize> {
147        self.min_body_size
148    }
149
150    pub fn allow_streaming_bodies(&self) -> bool {
151        self.allow_streaming_bodies
152    }
153
154    pub fn compression(&self) -> CompressionConfig {
155        self.compression
156    }
157
158    pub fn refresh_before(&self) -> Duration {
159        self.refresh_before
160    }
161
162    pub fn ttl(&self) -> Duration {
163        self.ttl
164    }
165
166    pub fn negative_ttl(&self) -> Duration {
167        self.negative_ttl
168    }
169
170    pub fn stale_while_revalidate(&self) -> Duration {
171        self.stale_while_revalidate
172    }
173
174    pub fn ml_logging(&self) -> &MLLoggingConfig {
175        &self.ml_logging
176    }
177
178    pub fn tag_policy(&self) -> &TagPolicy {
179        &self.tag_policy
180    }
181
182    pub fn streaming_policy(&self) -> &StreamingPolicy {
183        &self.streaming_policy
184    }
185
186    /// Extracts tags for a request using the configured tag extractor.
187    pub fn extract_tags(&self, method: &Method, uri: &http::Uri) -> Vec<String> {
188        if !self.tag_policy.enabled {
189            return Vec::new();
190        }
191
192        if let Some(ref extractor) = self.tag_extractor {
193            let tags = extractor(method, uri);
194            self.tag_policy.validate_tags(tags)
195        } else {
196            Vec::new()
197        }
198    }
199
200    /// Returns the headers that should be cached based on the allowlist.
201    pub fn headers_to_cache(&self, headers: &HeaderMap) -> Vec<(String, Vec<u8>)> {
202        match &self.header_allowlist {
203            Some(allowlist) => headers
204                .iter()
205                .filter(|(name, _)| allowlist.contains(&name.as_str().to_ascii_lowercase()))
206                .map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
207                .collect(),
208            None => headers
209                .iter()
210                .map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
211                .collect(),
212        }
213    }
214
215    pub fn with_ttl(mut self, ttl: Duration) -> Self {
216        self.ttl = ttl;
217        self
218    }
219
220    pub fn with_negative_ttl(mut self, ttl: Duration) -> Self {
221        self.negative_ttl = ttl;
222        self
223    }
224
225    pub fn with_statuses(mut self, statuses: impl IntoIterator<Item = u16>) -> Self {
226        self.cache_statuses = statuses.into_iter().collect();
227        self
228    }
229
230    pub fn with_stale_while_revalidate(mut self, duration: Duration) -> Self {
231        self.stale_while_revalidate = duration;
232        self
233    }
234
235    pub fn with_refresh_before(mut self, duration: Duration) -> Self {
236        self.refresh_before = duration;
237        self
238    }
239
240    pub fn with_max_body_size(mut self, size: Option<usize>) -> Self {
241        self.max_body_size = size;
242        self
243    }
244
245    pub fn with_min_body_size(mut self, size: Option<usize>) -> Self {
246        self.min_body_size = size;
247        self
248    }
249
250    pub fn with_allow_streaming_bodies(mut self, allow: bool) -> Self {
251        self.allow_streaming_bodies = allow;
252        self
253    }
254
255    pub fn with_compression(mut self, config: CompressionConfig) -> Self {
256        self.compression = config;
257        self
258    }
259
260    pub fn with_respect_cache_control(mut self, enabled: bool) -> Self {
261        self.respect_cache_control = enabled;
262        self
263    }
264
265    pub fn with_method_predicate<F>(mut self, predicate: F) -> Self
266    where
267        F: Fn(&Method) -> bool + Send + Sync + 'static,
268    {
269        self.method_predicate = Some(Arc::new(predicate));
270        self
271    }
272
273    pub fn with_header_allowlist<I, S>(mut self, headers: I) -> Self
274    where
275        I: IntoIterator<Item = S>,
276        S: Into<String>,
277    {
278        self.header_allowlist = Some(
279            headers
280                .into_iter()
281                .map(|h| h.into().to_ascii_lowercase())
282                .collect(),
283        );
284        self
285    }
286
287    pub fn with_ml_logging(mut self, config: MLLoggingConfig) -> Self {
288        self.ml_logging = config;
289        self
290    }
291
292    pub fn with_tag_policy(mut self, policy: TagPolicy) -> Self {
293        self.tag_policy = policy;
294        self
295    }
296
297    pub fn with_tag_extractor<F>(mut self, extractor: F) -> Self
298    where
299        F: Fn(&Method, &http::Uri) -> Vec<String> + Send + Sync + 'static,
300    {
301        self.tag_extractor = Some(Arc::new(extractor));
302        self
303    }
304
305    pub fn with_streaming_policy(mut self, policy: StreamingPolicy) -> Self {
306        self.streaming_policy = policy;
307        self
308    }
309}
310
311impl Default for CachePolicy {
312    fn default() -> Self {
313        Self {
314            ttl: Duration::from_secs(60),
315            negative_ttl: Duration::from_secs(5),
316            stale_while_revalidate: Duration::from_secs(0),
317            refresh_before: Duration::from_secs(0),
318            max_body_size: None,
319            min_body_size: None,
320            cache_statuses: HashSet::from([200, 203, 300, 301, 404]),
321            respect_cache_control: true,
322            method_predicate: None,
323            header_allowlist: None,
324            allow_streaming_bodies: false,
325            compression: CompressionConfig::default(),
326            ml_logging: MLLoggingConfig::default(),
327            tag_policy: TagPolicy::default(),
328            tag_extractor: None,
329            streaming_policy: StreamingPolicy::default(),
330        }
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
    use std::collections::HashSet;

    /// Two-header map shared by the `headers_to_cache` tests.
    fn sample_headers() -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_static("content-type"),
            HeaderValue::from_static("text/plain"),
        );
        headers.insert(
            HeaderName::from_static("x-cacheable"),
            HeaderValue::from_static("yes"),
        );
        headers
    }

    #[test]
    fn ttl_for_prefers_primary_ttl_for_cacheable_status() {
        let policy = CachePolicy::default().with_ttl(Duration::from_secs(123));
        assert_eq!(
            policy.ttl_for(StatusCode::OK),
            Some(Duration::from_secs(123))
        );
    }

    #[test]
    fn ttl_for_uses_negative_ttl_for_client_error() {
        let policy = CachePolicy::default().with_negative_ttl(Duration::from_secs(9));
        assert_eq!(
            policy.ttl_for(StatusCode::BAD_REQUEST),
            Some(Duration::from_secs(9))
        );
        assert_eq!(policy.ttl_for(StatusCode::INTERNAL_SERVER_ERROR), None);
    }

    #[test]
    fn ttl_for_skips_client_errors_when_negative_ttl_is_zero() {
        let policy = CachePolicy::default().with_negative_ttl(Duration::from_secs(0));
        assert_eq!(policy.ttl_for(StatusCode::BAD_REQUEST), None);
        // Listed statuses keep the primary TTL even when they are client errors.
        assert_eq!(policy.ttl_for(StatusCode::NOT_FOUND), Some(policy.ttl()));
    }

    #[test]
    fn headers_to_cache_respects_allowlist() {
        let policy = CachePolicy::default().with_header_allowlist(["content-type"]);
        let cached = policy.headers_to_cache(&sample_headers());

        let expected = vec![("content-type".to_owned(), b"text/plain".to_vec())];
        assert_eq!(cached, expected);
    }

    #[test]
    fn headers_to_cache_allowlist_is_case_insensitive() {
        let policy = CachePolicy::default().with_header_allowlist(["Content-Type"]);
        let cached = policy.headers_to_cache(&sample_headers());

        let expected = vec![("content-type".to_owned(), b"text/plain".to_vec())];
        assert_eq!(cached, expected);
    }

    #[test]
    fn headers_to_cache_without_allowlist_keeps_every_header() {
        let mut cached = CachePolicy::default().headers_to_cache(&sample_headers());
        // HeaderMap iteration order is unspecified, so compare sorted.
        cached.sort();

        let expected = vec![
            ("content-type".to_owned(), b"text/plain".to_vec()),
            ("x-cacheable".to_owned(), b"yes".to_vec()),
        ];
        assert_eq!(cached, expected);
    }

    #[test]
    fn default_method_check_allows_only_get_and_head() {
        let policy = CachePolicy::default();
        assert!(policy.should_cache_method(&Method::GET));
        assert!(policy.should_cache_method(&Method::HEAD));
        assert!(!policy.should_cache_method(&Method::POST));
    }

    #[test]
    fn method_predicate_overrides_default_behavior() {
        let policy = CachePolicy::default().with_method_predicate(|method| method == Method::POST);
        assert!(!policy.should_cache_method(&Method::GET));
        assert!(policy.should_cache_method(&Method::POST));
    }

    #[test]
    fn with_statuses_updates_allowlist() {
        let policy = CachePolicy::default().with_statuses([201, 202]);
        assert_eq!(policy.ttl_for(StatusCode::CREATED), Some(policy.ttl()));
        assert_eq!(policy.ttl_for(StatusCode::OK), None);
        assert_eq!(policy.cache_statuses, HashSet::from([201_u16, 202_u16]));
    }

    #[test]
    fn new_uses_given_ttls_and_statuses() {
        let policy = CachePolicy::new(Duration::from_secs(10), Duration::from_secs(0), [204]);
        assert_eq!(
            policy.ttl_for(StatusCode::NO_CONTENT),
            Some(Duration::from_secs(10))
        );
        assert_eq!(policy.ttl_for(StatusCode::OK), None);
        assert_eq!(policy.ttl_for(StatusCode::NOT_FOUND), None);
        assert!(policy.respect_cache_control());
    }
}