1use http::{HeaderMap, Method, StatusCode};
2use std::collections::HashSet;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::logging::MLLoggingConfig;
8use crate::streaming::StreamingPolicy;
9use crate::tags::TagPolicy;
10
11type MethodPredicateFn = Arc<dyn Fn(&Method) -> bool + Send + Sync>;
13
14type TagExtractorFn = Arc<dyn Fn(&Method, &http::Uri) -> Vec<String> + Send + Sync>;
16
17#[derive(Clone)]
24pub struct CachePolicy {
25 ttl: Duration,
26 negative_ttl: Duration,
27 stale_while_revalidate: Duration,
28 refresh_before: Duration,
29 max_body_size: Option<usize>,
30 min_body_size: Option<usize>,
31 cache_statuses: HashSet<u16>,
32 respect_cache_control: bool,
33 method_predicate: Option<MethodPredicateFn>,
34 header_allowlist: Option<HashSet<String>>,
35 allow_streaming_bodies: bool,
36 compression: CompressionConfig,
37 ml_logging: MLLoggingConfig,
38 tag_policy: TagPolicy,
39 tag_extractor: Option<TagExtractorFn>,
40 streaming_policy: StreamingPolicy,
41}
42
/// Compression algorithm selectable in a [`CompressionConfig`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CompressionStrategy {
    /// No compression.
    None,
    /// gzip compression.
    Gzip,
}

/// Compression settings carried by a cache policy.
///
/// Plain `Copy` value type; equality and hashing are derived so callers can compare
/// configurations directly (e.g. `config.strategy == CompressionStrategy::Gzip`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CompressionConfig {
    /// Which compression algorithm to use.
    pub strategy: CompressionStrategy,
    /// Minimum size threshold for compression.
    /// NOTE(review): presumably measured in bytes — confirm against the code that applies it.
    pub min_size: usize,
}

impl Default for CompressionConfig {
    /// No compression, with a `min_size` threshold of 1024.
    fn default() -> Self {
        Self {
            strategy: CompressionStrategy::None,
            min_size: 1024,
        }
    }
}
65
66impl fmt::Debug for CachePolicy {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 f.debug_struct("CachePolicy")
69 .field("ttl", &self.ttl)
70 .field("negative_ttl", &self.negative_ttl)
71 .field("stale_while_revalidate", &self.stale_while_revalidate)
72 .field("refresh_before", &self.refresh_before)
73 .field("max_body_size", &self.max_body_size)
74 .field("min_body_size", &self.min_body_size)
75 .field("cache_statuses", &self.cache_statuses)
76 .field("respect_cache_control", &self.respect_cache_control)
77 .field(
78 "header_allowlist",
79 &self
80 .header_allowlist
81 .as_ref()
82 .map(|set| set.iter().collect::<Vec<_>>()),
83 )
84 .field("allow_streaming_bodies", &self.allow_streaming_bodies)
85 .field("compression", &self.compression)
86 .finish()
87 }
88}
89
90impl CachePolicy {
91 pub fn new(
93 ttl: Duration,
94 negative_ttl: Duration,
95 statuses: impl IntoIterator<Item = u16>,
96 ) -> Self {
97 Self {
98 ttl,
99 negative_ttl,
100 stale_while_revalidate: Duration::from_secs(0),
101 refresh_before: Duration::from_secs(0),
102 max_body_size: None,
103 min_body_size: None,
104 cache_statuses: statuses.into_iter().collect(),
105 respect_cache_control: true,
106 method_predicate: None,
107 header_allowlist: None,
108 allow_streaming_bodies: false,
109 compression: CompressionConfig::default(),
110 ml_logging: MLLoggingConfig::default(),
111 tag_policy: TagPolicy::default(),
112 tag_extractor: None,
113 streaming_policy: StreamingPolicy::default(),
114 }
115 }
116
117 pub fn ttl_for(&self, status: StatusCode) -> Option<Duration> {
119 if self.cache_statuses.contains(&status.as_u16()) {
120 Some(self.ttl)
121 } else if status.is_client_error() && !self.negative_ttl.is_zero() {
122 Some(self.negative_ttl)
123 } else {
124 None
125 }
126 }
127
128 pub fn should_cache_method(&self, method: &Method) -> bool {
130 if let Some(predicate) = &self.method_predicate {
131 predicate(method)
132 } else {
133 matches!(method, &Method::GET | &Method::HEAD)
134 }
135 }
136
137 pub fn respect_cache_control(&self) -> bool {
139 self.respect_cache_control
140 }
141
142 pub fn max_body_size(&self) -> Option<usize> {
143 self.max_body_size
144 }
145
146 pub fn min_body_size(&self) -> Option<usize> {
147 self.min_body_size
148 }
149
150 pub fn allow_streaming_bodies(&self) -> bool {
151 self.allow_streaming_bodies
152 }
153
154 pub fn compression(&self) -> CompressionConfig {
155 self.compression
156 }
157
158 pub fn refresh_before(&self) -> Duration {
159 self.refresh_before
160 }
161
162 pub fn ttl(&self) -> Duration {
163 self.ttl
164 }
165
166 pub fn negative_ttl(&self) -> Duration {
167 self.negative_ttl
168 }
169
170 pub fn stale_while_revalidate(&self) -> Duration {
171 self.stale_while_revalidate
172 }
173
174 pub fn ml_logging(&self) -> &MLLoggingConfig {
175 &self.ml_logging
176 }
177
178 pub fn tag_policy(&self) -> &TagPolicy {
179 &self.tag_policy
180 }
181
182 pub fn streaming_policy(&self) -> &StreamingPolicy {
183 &self.streaming_policy
184 }
185
186 pub fn extract_tags(&self, method: &Method, uri: &http::Uri) -> Vec<String> {
188 if !self.tag_policy.enabled {
189 return Vec::new();
190 }
191
192 if let Some(ref extractor) = self.tag_extractor {
193 let tags = extractor(method, uri);
194 self.tag_policy.validate_tags(tags)
195 } else {
196 Vec::new()
197 }
198 }
199
200 pub fn headers_to_cache(&self, headers: &HeaderMap) -> Vec<(String, Vec<u8>)> {
202 match &self.header_allowlist {
203 Some(allowlist) => headers
204 .iter()
205 .filter(|(name, _)| allowlist.contains(&name.as_str().to_ascii_lowercase()))
206 .map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
207 .collect(),
208 None => headers
209 .iter()
210 .map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
211 .collect(),
212 }
213 }
214
215 pub fn with_ttl(mut self, ttl: Duration) -> Self {
216 self.ttl = ttl;
217 self
218 }
219
220 pub fn with_negative_ttl(mut self, ttl: Duration) -> Self {
221 self.negative_ttl = ttl;
222 self
223 }
224
225 pub fn with_statuses(mut self, statuses: impl IntoIterator<Item = u16>) -> Self {
226 self.cache_statuses = statuses.into_iter().collect();
227 self
228 }
229
230 pub fn with_stale_while_revalidate(mut self, duration: Duration) -> Self {
231 self.stale_while_revalidate = duration;
232 self
233 }
234
235 pub fn with_refresh_before(mut self, duration: Duration) -> Self {
236 self.refresh_before = duration;
237 self
238 }
239
240 pub fn with_max_body_size(mut self, size: Option<usize>) -> Self {
241 self.max_body_size = size;
242 self
243 }
244
245 pub fn with_min_body_size(mut self, size: Option<usize>) -> Self {
246 self.min_body_size = size;
247 self
248 }
249
250 pub fn with_allow_streaming_bodies(mut self, allow: bool) -> Self {
251 self.allow_streaming_bodies = allow;
252 self
253 }
254
255 pub fn with_compression(mut self, config: CompressionConfig) -> Self {
256 self.compression = config;
257 self
258 }
259
260 pub fn with_respect_cache_control(mut self, enabled: bool) -> Self {
261 self.respect_cache_control = enabled;
262 self
263 }
264
265 pub fn with_method_predicate<F>(mut self, predicate: F) -> Self
266 where
267 F: Fn(&Method) -> bool + Send + Sync + 'static,
268 {
269 self.method_predicate = Some(Arc::new(predicate));
270 self
271 }
272
273 pub fn with_header_allowlist<I, S>(mut self, headers: I) -> Self
274 where
275 I: IntoIterator<Item = S>,
276 S: Into<String>,
277 {
278 self.header_allowlist = Some(
279 headers
280 .into_iter()
281 .map(|h| h.into().to_ascii_lowercase())
282 .collect(),
283 );
284 self
285 }
286
287 pub fn with_ml_logging(mut self, config: MLLoggingConfig) -> Self {
288 self.ml_logging = config;
289 self
290 }
291
292 pub fn with_tag_policy(mut self, policy: TagPolicy) -> Self {
293 self.tag_policy = policy;
294 self
295 }
296
297 pub fn with_tag_extractor<F>(mut self, extractor: F) -> Self
298 where
299 F: Fn(&Method, &http::Uri) -> Vec<String> + Send + Sync + 'static,
300 {
301 self.tag_extractor = Some(Arc::new(extractor));
302 self
303 }
304
305 pub fn with_streaming_policy(mut self, policy: StreamingPolicy) -> Self {
306 self.streaming_policy = policy;
307 self
308 }
309}
310
311impl Default for CachePolicy {
312 fn default() -> Self {
313 Self {
314 ttl: Duration::from_secs(60),
315 negative_ttl: Duration::from_secs(5),
316 stale_while_revalidate: Duration::from_secs(0),
317 refresh_before: Duration::from_secs(0),
318 max_body_size: None,
319 min_body_size: None,
320 cache_statuses: HashSet::from([200, 203, 300, 301, 404]),
321 respect_cache_control: true,
322 method_predicate: None,
323 header_allowlist: None,
324 allow_streaming_bodies: false,
325 compression: CompressionConfig::default(),
326 ml_logging: MLLoggingConfig::default(),
327 tag_policy: TagPolicy::default(),
328 tag_extractor: None,
329 streaming_policy: StreamingPolicy::default(),
330 }
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
    use std::collections::HashSet;

    #[test]
    fn ttl_for_prefers_primary_ttl_for_cacheable_status() {
        let policy = CachePolicy::default().with_ttl(Duration::from_secs(123));
        assert_eq!(
            policy.ttl_for(StatusCode::OK),
            Some(Duration::from_secs(123))
        );
    }

    #[test]
    fn ttl_for_uses_negative_ttl_for_client_error() {
        let policy = CachePolicy::default().with_negative_ttl(Duration::from_secs(9));
        assert_eq!(
            policy.ttl_for(StatusCode::BAD_REQUEST),
            Some(Duration::from_secs(9))
        );
        assert_eq!(policy.ttl_for(StatusCode::INTERNAL_SERVER_ERROR), None);
    }

    #[test]
    fn ttl_for_skips_client_errors_when_negative_ttl_is_zero() {
        let policy = CachePolicy::default().with_negative_ttl(Duration::ZERO);
        assert_eq!(policy.ttl_for(StatusCode::BAD_REQUEST), None);
        // 404 is in the default cacheable set, so it still uses the primary TTL.
        assert_eq!(policy.ttl_for(StatusCode::NOT_FOUND), Some(policy.ttl()));
    }

    #[test]
    fn new_sets_ttls_and_statuses_and_defaults_the_rest() {
        let policy = CachePolicy::new(Duration::from_secs(30), Duration::from_secs(2), [200]);
        assert_eq!(policy.ttl(), Duration::from_secs(30));
        assert_eq!(policy.negative_ttl(), Duration::from_secs(2));
        assert_eq!(policy.cache_statuses, HashSet::from([200_u16]));
        assert_eq!(policy.stale_while_revalidate(), Duration::ZERO);
        assert_eq!(policy.refresh_before(), Duration::ZERO);
        assert_eq!(policy.max_body_size(), None);
        assert_eq!(policy.min_body_size(), None);
        assert!(policy.respect_cache_control());
        assert!(!policy.allow_streaming_bodies());
        assert!(policy.method_predicate.is_none());
        assert!(policy.header_allowlist.is_none());
        assert!(policy.tag_extractor.is_none());
    }

    #[test]
    fn default_method_filter_allows_only_get_and_head() {
        let policy = CachePolicy::default();
        assert!(policy.should_cache_method(&Method::GET));
        assert!(policy.should_cache_method(&Method::HEAD));
        assert!(!policy.should_cache_method(&Method::POST));
        assert!(!policy.should_cache_method(&Method::DELETE));
    }

    #[test]
    fn headers_to_cache_respects_allowlist() {
        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_static("content-type"),
            HeaderValue::from_static("text/plain"),
        );
        headers.insert(
            HeaderName::from_static("x-cacheable"),
            HeaderValue::from_static("yes"),
        );

        let policy = CachePolicy::default().with_header_allowlist(["content-type"]);
        let cached = policy.headers_to_cache(&headers);

        let expected = vec![("content-type".to_owned(), b"text/plain".to_vec())];
        assert_eq!(cached, expected);
    }

    #[test]
    fn header_allowlist_matches_case_insensitively() {
        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_static("content-type"),
            HeaderValue::from_static("text/plain"),
        );
        headers.insert(
            HeaderName::from_static("x-other"),
            HeaderValue::from_static("no"),
        );

        let policy = CachePolicy::default().with_header_allowlist(["Content-Type"]);
        let cached = policy.headers_to_cache(&headers);

        assert_eq!(
            cached,
            vec![("content-type".to_owned(), b"text/plain".to_vec())]
        );
    }

    #[test]
    fn headers_to_cache_without_allowlist_keeps_every_value() {
        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_static("content-type"),
            HeaderValue::from_static("text/plain"),
        );
        headers.append(
            HeaderName::from_static("x-multi"),
            HeaderValue::from_static("a"),
        );
        headers.append(
            HeaderName::from_static("x-multi"),
            HeaderValue::from_static("b"),
        );

        let mut cached = CachePolicy::default().headers_to_cache(&headers);
        // Sort so the assertion does not depend on HeaderMap iteration order.
        cached.sort();

        assert_eq!(
            cached,
            vec![
                ("content-type".to_owned(), b"text/plain".to_vec()),
                ("x-multi".to_owned(), b"a".to_vec()),
                ("x-multi".to_owned(), b"b".to_vec()),
            ]
        );
    }

    #[test]
    fn method_predicate_overrides_default_behavior() {
        let policy = CachePolicy::default().with_method_predicate(|method| method == Method::POST);
        assert!(!policy.should_cache_method(&Method::GET));
        assert!(policy.should_cache_method(&Method::POST));
    }

    #[test]
    fn with_statuses_updates_allowlist() {
        let policy = CachePolicy::default().with_statuses([201, 202]);
        assert_eq!(policy.ttl_for(StatusCode::CREATED), Some(policy.ttl()));
        assert_eq!(policy.ttl_for(StatusCode::OK), None);
        assert_eq!(policy.cache_statuses, HashSet::from([201_u16, 202_u16]));
    }

    #[test]
    fn builders_round_trip_through_getters() {
        let policy = CachePolicy::default()
            .with_stale_while_revalidate(Duration::from_secs(7))
            .with_refresh_before(Duration::from_secs(3))
            .with_max_body_size(Some(4096))
            .with_min_body_size(Some(16))
            .with_allow_streaming_bodies(true)
            .with_respect_cache_control(false)
            .with_compression(CompressionConfig {
                strategy: CompressionStrategy::Gzip,
                min_size: 10,
            });

        assert_eq!(policy.stale_while_revalidate(), Duration::from_secs(7));
        assert_eq!(policy.refresh_before(), Duration::from_secs(3));
        assert_eq!(policy.max_body_size(), Some(4096));
        assert_eq!(policy.min_body_size(), Some(16));
        assert!(policy.allow_streaming_bodies());
        assert!(!policy.respect_cache_control());
        assert_eq!(policy.compression().min_size, 10);
        assert!(matches!(
            policy.compression().strategy,
            CompressionStrategy::Gzip
        ));
    }
}