1use std::error::Error as StdError;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::{Duration, SystemTime};
5
6use bytes::Bytes;
7use dashmap::mapref::entry::Entry;
8use dashmap::DashMap;
9use futures_util::future::BoxFuture;
10use http::header::{CACHE_CONTROL, PRAGMA};
11use http::{HeaderMap, Method, Request, Response, Uri};
12use http_body::Body;
13use http_body_util::combinators::BoxBody;
14use http_body_util::{BodyExt, Full};
15use tokio::sync::{Mutex, OwnedMutexGuard};
16use tower::{Layer, Service, ServiceExt};
17
18#[cfg(feature = "metrics")]
19use metrics::{counter, histogram};
20
21use crate::backend::memory::InMemoryBackend;
22use crate::backend::{CacheBackend, CacheEntry, CacheRead};
23use crate::chunks::{ChunkCache, ChunkMetadata};
24#[cfg(feature = "compression")]
25use crate::policy::CompressionStrategy;
26use crate::policy::{CachePolicy, CompressionConfig};
27use crate::range::{is_partial_content, parse_range_header, RangeHandling};
28use crate::refresh::{AutoRefreshConfig, RefreshCallback, RefreshManager, RefreshMetadata};
29#[cfg(feature = "tracing")]
30use crate::streaming::extract_size_info;
31use crate::streaming::{should_stream, StreamingDecision};
32
/// Type-erased error used as [`CacheService`]'s `Error` type and as the error type of
/// [`SyncBoxBody`]. It is `Send + Sync`.
pub type BoxError = Box<dyn StdError + Send + Sync>;
34
pin_project_lite::pin_project! {
    /// Response body type produced by [`CacheService`]: a boxed body of [`Bytes`] with
    /// [`BoxError`] errors.
    ///
    /// Its [`Body`] implementation forwards `poll_frame`, `is_end_stream`, and `size_hint`
    /// to the wrapped body.
    pub struct SyncBoxBody {
        // Pinned so `poll_frame` can be forwarded through the pin projection.
        #[pin]
        inner: BoxBody<Bytes, BoxError>,
    }
}
46
47impl SyncBoxBody {
48 pub fn new(inner: BoxBody<Bytes, BoxError>) -> Self {
50 Self { inner }
51 }
52}
53
54unsafe impl Sync for SyncBoxBody {}
57
58impl Body for SyncBoxBody {
59 type Data = Bytes;
60 type Error = BoxError;
61
62 fn poll_frame(
63 self: std::pin::Pin<&mut Self>,
64 cx: &mut std::task::Context<'_>,
65 ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
66 self.project().inner.poll_frame(cx)
67 }
68
69 fn is_end_stream(&self) -> bool {
70 self.inner.is_end_stream()
71 }
72
73 fn size_hint(&self) -> http_body::SizeHint {
74 self.inner.size_hint()
75 }
76}
77
/// Shared, thread-safe key function: maps a request's method and URI to a cache key, or
/// `None` when the request gets no cache key (and is therefore neither looked up nor stored).
type KeyExtractorFn = Arc<dyn Fn(&Method, &Uri) -> Option<String> + Send + Sync>;
80
/// Tower [`Layer`] that wraps a service with response caching.
///
/// The `Arc` fields (`locks`, `refresh_manager`, `chunk_cache`) are shared between clones
/// of the layer and every [`CacheService`] it produces; `backend`, `policy` and
/// `key_extractor` are cloned into each service.
#[derive(Clone)]
pub struct CacheLayer<B> {
    /// Storage backend for cached responses.
    backend: B,
    /// Caching rules (TTLs, size limits, statuses, streaming/compression settings).
    policy: CachePolicy,
    /// Derives the cache key from method and URI.
    key_extractor: KeyExtractor,
    /// Per-key stampede locks shared by all services built from this layer.
    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
    /// Auto-refresh manager; `Some` only when auto-refresh was enabled on the builder.
    refresh_manager: Option<Arc<RefreshManager>>,
    /// Chunk cache for range requests; `Some` only when the streaming policy enables it.
    chunk_cache: Option<Arc<ChunkCache>>,
}
98
/// Strategy that turns a request's method and URI into a cache key.
///
/// Cloning is cheap: the key function is held behind an `Arc`.
#[derive(Clone)]
pub struct KeyExtractor {
    inner: KeyExtractorFn,
}
108
109impl KeyExtractor {
110 pub fn path_and_query() -> Self {
112 Self {
113 inner: Arc::new(|method: &Method, uri: &Uri| {
114 if matches!(method, &Method::GET | &Method::HEAD) {
115 let mut key = uri.path().to_owned();
116 if let Some(query) = uri.query() {
117 key.push('?');
118 key.push_str(query);
119 }
120 Some(key)
121 } else {
122 None
123 }
124 }),
125 }
126 }
127
128 pub fn path() -> Self {
129 Self {
130 inner: Arc::new(|method: &Method, uri: &Uri| {
131 if matches!(method, &Method::GET | &Method::HEAD) {
132 Some(uri.path().to_owned())
133 } else {
134 None
135 }
136 }),
137 }
138 }
139
140 pub fn custom<F>(func: F) -> Self
141 where
142 F: Fn(&Method, &Uri) -> Option<String> + Send + Sync + 'static,
143 {
144 Self {
145 inner: Arc::new(func),
146 }
147 }
148
149 pub fn extract(&self, method: &Method, uri: &Uri) -> Option<String> {
153 (self.inner)(method, uri)
154 }
155}
156
157impl Default for KeyExtractor {
158 fn default() -> Self {
159 Self::path_and_query()
160 }
161}
162
/// Builder for [`CacheLayer`]: collects the backend, policy, key extractor, and optional
/// auto-refresh settings, then produces the layer via `build`.
pub struct CacheLayerBuilder<B> {
    backend: B,
    policy: CachePolicy,
    key_extractor: KeyExtractor,
    /// Auto-refresh settings; `build` only creates a manager when `enabled` is set.
    auto_refresh_config: Option<AutoRefreshConfig>,
}
170
171impl<B> CacheLayerBuilder<B>
172where
173 B: CacheBackend,
174{
175 pub fn new(backend: B) -> Self {
176 Self {
177 backend,
178 policy: CachePolicy::default(),
179 key_extractor: KeyExtractor::default(),
180 auto_refresh_config: None,
181 }
182 }
183
184 pub fn policy(mut self, policy: CachePolicy) -> Self {
186 self.policy = policy;
187 self
188 }
189
190 pub fn ttl(mut self, ttl: Duration) -> Self {
192 self.policy = self.policy.with_ttl(ttl);
193 self
194 }
195
196 pub fn negative_ttl(mut self, ttl: Duration) -> Self {
198 self.policy = self.policy.with_negative_ttl(ttl);
199 self
200 }
201
202 pub fn stale_while_revalidate(mut self, duration: Duration) -> Self {
203 self.policy = self.policy.with_stale_while_revalidate(duration);
204 self
205 }
206
207 pub fn refresh_before(mut self, duration: Duration) -> Self {
208 self.policy = self.policy.with_refresh_before(duration);
209 self
210 }
211
212 pub fn max_body_size(mut self, size: Option<usize>) -> Self {
213 self.policy = self.policy.with_max_body_size(size);
214 self
215 }
216
217 pub fn min_body_size(mut self, size: Option<usize>) -> Self {
218 self.policy = self.policy.with_min_body_size(size);
219 self
220 }
221
222 pub fn allow_streaming_bodies(mut self, allow: bool) -> Self {
223 self.policy = self.policy.with_allow_streaming_bodies(allow);
224 self
225 }
226
227 pub fn compression(mut self, config: CompressionConfig) -> Self {
228 self.policy = self.policy.with_compression(config);
229 self
230 }
231
232 pub fn respect_cache_control(mut self, enabled: bool) -> Self {
233 self.policy = self.policy.with_respect_cache_control(enabled);
234 self
235 }
236
237 pub fn statuses(mut self, statuses: impl IntoIterator<Item = u16>) -> Self {
238 self.policy = self.policy.with_statuses(statuses);
239 self
240 }
241
242 pub fn method_predicate<F>(mut self, predicate: F) -> Self
243 where
244 F: Fn(&Method) -> bool + Send + Sync + 'static,
245 {
246 self.policy = self.policy.with_method_predicate(predicate);
247 self
248 }
249
250 pub fn header_allowlist<I, S>(mut self, headers: I) -> Self
251 where
252 I: IntoIterator<Item = S>,
253 S: Into<String>,
254 {
255 self.policy = self.policy.with_header_allowlist(headers);
256 self
257 }
258
259 pub fn key_extractor(mut self, extractor: KeyExtractor) -> Self {
260 self.key_extractor = extractor;
261 self
262 }
263
264 pub fn auto_refresh(mut self, config: AutoRefreshConfig) -> Self {
269 self.auto_refresh_config = Some(config);
270 self
271 }
272
273 pub fn build(self) -> CacheLayer<B> {
274 let refresh_manager = self
275 .auto_refresh_config
276 .filter(|cfg| cfg.enabled)
277 .map(|cfg| Arc::new(RefreshManager::new(cfg)));
278
279 let chunk_cache = if self.policy.streaming_policy().enable_chunk_cache {
281 Some(Arc::new(ChunkCache::new(
282 self.policy.streaming_policy().chunk_size,
283 )))
284 } else {
285 None
286 };
287
288 CacheLayer {
289 backend: self.backend,
290 policy: self.policy,
291 key_extractor: self.key_extractor,
292 locks: Arc::new(DashMap::new()),
293 refresh_manager,
294 chunk_cache,
295 }
296 }
297}
298
299impl CacheLayer<InMemoryBackend> {
300 pub fn new_in_memory(max_capacity: u64) -> Self {
302 CacheLayerBuilder::new(InMemoryBackend::new(max_capacity)).build()
303 }
304}
305
306impl<B> CacheLayer<B>
307where
308 B: CacheBackend,
309{
310 pub fn new(backend: B) -> Self {
312 CacheLayerBuilder::new(backend).build()
313 }
314
315 pub fn builder(backend: B) -> CacheLayerBuilder<B> {
317 CacheLayerBuilder::new(backend)
318 }
319
320 pub fn with_policy(mut self, policy: CachePolicy) -> Self {
321 self.policy = policy;
322 self
323 }
324
325 pub fn with_ttl(mut self, ttl: Duration) -> Self {
326 self.policy = self.policy.clone().with_ttl(ttl);
327 self
328 }
329
330 pub fn with_negative_ttl(mut self, ttl: Duration) -> Self {
331 self.policy = self.policy.clone().with_negative_ttl(ttl);
332 self
333 }
334
335 pub fn with_stale_while_revalidate(mut self, duration: Duration) -> Self {
336 self.policy = self.policy.clone().with_stale_while_revalidate(duration);
337 self
338 }
339
340 pub fn with_refresh_before(mut self, duration: Duration) -> Self {
341 self.policy = self.policy.clone().with_refresh_before(duration);
342 self
343 }
344
345 pub fn with_max_body_size(mut self, size: Option<usize>) -> Self {
346 self.policy = self.policy.clone().with_max_body_size(size);
347 self
348 }
349
350 pub fn with_min_body_size(mut self, size: Option<usize>) -> Self {
351 self.policy = self.policy.clone().with_min_body_size(size);
352 self
353 }
354
355 pub fn with_allow_streaming_bodies(mut self, allow: bool) -> Self {
356 self.policy = self.policy.clone().with_allow_streaming_bodies(allow);
357 self
358 }
359
360 pub fn with_compression(mut self, config: CompressionConfig) -> Self {
361 self.policy = self.policy.clone().with_compression(config);
362 self
363 }
364
365 pub fn with_respect_cache_control(mut self, enabled: bool) -> Self {
366 self.policy = self.policy.clone().with_respect_cache_control(enabled);
367 self
368 }
369
370 pub fn with_cache_statuses(mut self, statuses: impl IntoIterator<Item = u16>) -> Self {
371 self.policy = self.policy.clone().with_statuses(statuses);
372 self
373 }
374
375 pub fn with_method_predicate<F>(mut self, predicate: F) -> Self
376 where
377 F: Fn(&Method) -> bool + Send + Sync + 'static,
378 {
379 self.policy = self.policy.clone().with_method_predicate(predicate);
380 self
381 }
382
383 pub fn with_header_allowlist<I, S>(mut self, headers: I) -> Self
384 where
385 I: IntoIterator<Item = S>,
386 S: Into<String>,
387 {
388 self.policy = self.policy.clone().with_header_allowlist(headers);
389 self
390 }
391
392 pub fn with_key_extractor(mut self, extractor: KeyExtractor) -> Self {
393 self.key_extractor = extractor;
394 self
395 }
396
397 pub async fn init_auto_refresh<S, ResBody>(&self, service: S) -> Result<(), String>
412 where
413 S: Service<Request<()>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
414 S::Future: Send + 'static,
415 S::Error: Into<BoxError> + Send,
416 ResBody: Body<Data = Bytes> + Send + 'static,
417 ResBody::Error: Into<BoxError> + Send,
418 B: Clone,
419 {
420 if let Some(ref manager) = self.refresh_manager {
421 let callback = Arc::new(CacheRefreshCallback::new(
422 service,
423 self.backend.clone(),
424 self.policy.clone(),
425 self.key_extractor.clone(),
426 ));
427 manager.start(callback).await
428 } else {
429 Ok(())
430 }
431 }
432}
433
434impl<S, B> Layer<S> for CacheLayer<B>
435where
436 B: CacheBackend,
437{
438 type Service = CacheService<S, B>;
439
440 fn layer(&self, inner: S) -> Self::Service {
441 CacheService {
442 inner,
443 backend: self.backend.clone(),
444 policy: self.policy.clone(),
445 key_extractor: self.key_extractor.clone(),
446 locks: self.locks.clone(),
447 refresh_manager: self.refresh_manager.clone(),
448 chunk_cache: self.chunk_cache.clone(),
449 }
450 }
451}
452
453impl<B> Drop for CacheLayer<B> {
454 fn drop(&mut self) {
455 if let Some(manager) = &self.refresh_manager {
458 let manager = manager.clone();
459 if let Ok(handle) = tokio::runtime::Handle::try_current() {
462 handle.spawn(async move {
463 manager.shutdown().await;
464 });
465 }
466 }
467 }
468}
469
/// Service produced by [`CacheLayer`]: serves cached responses where possible and
/// forwards everything else to `inner`.
#[derive(Clone)]
pub struct CacheService<S, B> {
    /// Wrapped service that produces responses on a cache miss.
    inner: S,
    /// Storage backend (a clone of the layer's).
    backend: B,
    /// Caching rules (a clone of the layer's).
    policy: CachePolicy,
    /// Derives the cache key from method and URI.
    key_extractor: KeyExtractor,
    /// Per-key stampede locks shared with the layer and sibling services.
    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
    /// Auto-refresh manager; `Some` only when auto-refresh is enabled.
    refresh_manager: Option<Arc<RefreshManager>>,
    /// Chunk cache for range requests; `Some` only when the streaming policy enables it.
    chunk_cache: Option<Arc<ChunkCache>>,
}
480
/// [`RefreshCallback`] that re-fetches a key through `inner` and writes the fresh
/// response to `backend` according to `policy`.
struct CacheRefreshCallback<S, B> {
    /// Service used to re-issue the reconstructed request.
    inner: S,
    /// Backend the refreshed entry is written to.
    backend: B,
    /// Policy deciding whether and for how long the refreshed response is stored.
    policy: CachePolicy,
}
487
488impl<S, B> CacheRefreshCallback<S, B> {
489 fn new(inner: S, backend: B, policy: CachePolicy, _key_extractor: KeyExtractor) -> Self {
490 Self {
491 inner,
492 backend,
493 policy,
494 }
495 }
496}
497
498impl<S, B, ResBody> RefreshCallback for CacheRefreshCallback<S, B>
499where
500 S: Service<Request<()>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
501 S::Future: Send + 'static,
502 S::Error: Into<BoxError> + Send,
503 ResBody: Body<Data = Bytes> + Send + 'static,
504 ResBody::Error: Into<BoxError> + Send,
505 B: CacheBackend,
506{
507 fn refresh(&self, key: String, metadata: RefreshMetadata) -> crate::refresh::RefreshFuture {
508 let backend = self.backend.clone();
509 let policy = self.policy.clone();
510 let inner = self.inner.clone();
511
512 Box::pin(async move {
513 #[cfg(feature = "tracing")]
514 tracing::debug!(key = %key, uri = %metadata.uri, "Auto-refresh triggered");
515
516 let request = match metadata.try_into_request() {
518 Some(req) => req,
519 None => {
520 #[cfg(feature = "tracing")]
521 tracing::warn!(key = %key, "Failed to reconstruct request for auto-refresh");
522 return Err("Failed to reconstruct request".into());
523 }
524 };
525
526 let service = inner;
528 let response = match service.oneshot(request).await {
529 Ok(resp) => resp,
530 Err(_err) => {
531 #[cfg(feature = "tracing")]
532 tracing::error!(key = %key, "Service error during auto-refresh");
533 return Err("Service error during refresh".into());
534 }
535 };
536
537 let (parts, body) = response.into_parts();
538
539 let collected = match BodyExt::collect(body).await {
541 Ok(c) => c,
542 Err(_err) => {
543 #[cfg(feature = "tracing")]
544 tracing::error!(key = %key, "Body collection error during auto-refresh");
545 return Err("Body collection error".into());
546 }
547 };
548
549 let cache_bytes = collected.to_bytes();
550
551 let body_too_large = policy
553 .max_body_size()
554 .is_some_and(|max| cache_bytes.len() > max);
555 let body_too_small = policy
556 .min_body_size()
557 .is_some_and(|min| cache_bytes.len() < min);
558
559 if body_too_large || body_too_small {
560 return Ok(()); }
562
563 if let Some(ttl) = policy.ttl_for(parts.status) {
565 if !ttl.is_zero() {
566 let stale_for = policy.stale_while_revalidate();
567 let headers_to_cache = policy.headers_to_cache(&parts.headers);
568 let (compressed_bytes, _compressed) =
569 maybe_compress(cache_bytes, policy.compression());
570
571 let entry = CacheEntry::new(
572 parts.status,
573 parts.version,
574 headers_to_cache,
575 compressed_bytes,
576 );
577
578 if let Err(_err) = backend.set(key.clone(), entry, ttl, stale_for).await {
579 #[cfg(feature = "tracing")]
580 tracing::error!(key = %key, "Failed to store refreshed entry");
581 return Err("Failed to store entry".into());
582 }
583 }
584 }
585
586 Ok(())
587 })
588 }
589}
590
591impl<S, B, ReqBody, ResBody> Service<Request<ReqBody>> for CacheService<S, B>
597where
598 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
599 S::Future: Send + 'static,
600 S::Error: Into<BoxError> + Send,
601 ReqBody: Send + 'static,
602 ResBody: Body<Data = Bytes> + Send + Sync + 'static,
603 ResBody::Error: Into<BoxError> + Send,
604 B: CacheBackend,
605{
606 type Response = Response<SyncBoxBody>;
607 type Error = BoxError;
608 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
609
610 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
611 self.inner.poll_ready(cx).map_err(Into::into)
612 }
613
614 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
615 let method = req.method().clone();
616 let uri = req.uri().clone();
617 let should_cache_method = self.policy.should_cache_method(&method);
618 let request_bypass =
619 self.policy.respect_cache_control() && cache_control_disallows(req.headers());
620 let key = if should_cache_method && !request_bypass {
621 self.key_extractor.extract(&method, &uri)
622 } else {
623 None
624 };
625
626 let backend = self.backend.clone();
627 let policy = self.policy.clone();
628 let locks = self.locks.clone();
629 let inner = self.inner.clone();
630 let stale_window = policy.stale_while_revalidate();
631 let refresh_before = policy.refresh_before();
632 let refresh_manager = self.refresh_manager.clone();
633 let chunk_cache = self.chunk_cache.clone();
634
635 let range_request = parse_range_header(req.headers());
637
638 let refresh_metadata = if refresh_manager.is_some() && key.is_some() {
640 Some(RefreshMetadata::from_request(&req))
641 } else {
642 None
643 };
644
645 Box::pin(async move {
646 #[cfg(feature = "tracing")]
647 tracing::debug!(method = %method, uri = %uri, "cache_call");
648
649 if let (Some(range_req), Some(ref chunk_cache), Some(ref key_ref)) =
651 (range_request.as_ref(), &chunk_cache, &key)
652 {
653 if let Some(entry) = chunk_cache.get(key_ref) {
654 if let Some(normalized) = range_req.normalize(entry.metadata.total_size) {
656 let end = normalized.end.unwrap_or(entry.metadata.total_size - 1);
657
658 if let Some(range_data) = entry.get_range(normalized.start, end) {
660 #[cfg(feature = "metrics")]
661 counter!("tower_http_cache.chunk_cache_hit").increment(1);
662
663 #[cfg(feature = "tracing")]
664 tracing::debug!(
665 key = %key_ref,
666 start = normalized.start,
667 end = end,
668 "chunk_cache_hit"
669 );
670
671 let mut response = Response::builder()
673 .status(http::StatusCode::PARTIAL_CONTENT)
674 .body(SyncBoxBody::new(
675 Full::from(range_data).map_err(Into::into).boxed(),
676 ))
677 .unwrap();
678
679 for (name, value) in &entry.metadata.headers {
681 if let (Ok(header_name), Ok(header_value)) = (
682 http::header::HeaderName::from_bytes(name.as_bytes()),
683 http::header::HeaderValue::from_bytes(value),
684 ) {
685 response.headers_mut().insert(header_name, header_value);
686 }
687 }
688
689 let content_range = format!(
691 "bytes {}-{}/{}",
692 normalized.start, end, entry.metadata.total_size
693 );
694 response.headers_mut().insert(
695 http::header::CONTENT_RANGE,
696 http::header::HeaderValue::from_str(&content_range).unwrap(),
697 );
698
699 let content_length = (end - normalized.start + 1).to_string();
701 response.headers_mut().insert(
702 http::header::CONTENT_LENGTH,
703 http::header::HeaderValue::from_str(&content_length).unwrap(),
704 );
705
706 return Ok(response);
707 }
708 }
709 }
710
711 #[cfg(feature = "metrics")]
712 counter!("tower_http_cache.chunk_cache_miss").increment(1);
713 }
714
715 let mut stale_entry: Option<CacheEntry> = None;
716 if let Some(ref key_ref) = key {
717 if let Ok(Some(hit)) = backend.get(key_ref).await {
718 match classify_hit(hit, stale_window, refresh_before) {
719 HitState::Fresh(entry) => {
720 #[cfg(feature = "metrics")]
721 counter!("tower_http_cache.hit").increment(1);
722
723 if let Some(ref manager) = refresh_manager {
725 manager.tracker().record_hit(key_ref);
726 }
727
728 return Ok(entry.into_response());
729 }
730 HitState::Stale(entry) => {
731 #[cfg(feature = "metrics")]
732 counter!("tower_http_cache.stale_hit").increment(1);
733
734 if let Some(ref manager) = refresh_manager {
736 manager.tracker().record_hit(key_ref);
737 }
738
739 stale_entry = Some(entry);
740 }
741 HitState::Expired => {}
742 }
743 }
744 }
745
746 let mut primary_guard: Option<StampedeGuard> = None;
747 if let Some(ref key_ref) = key {
748 match StampedeGuard::acquire_handle(locks.clone(), key_ref.clone()).await {
749 StampedeHandle::Primary(guard) => {
750 primary_guard = Some(guard);
751 }
752 StampedeHandle::Secondary(lock) => {
753 if let Some(entry) = stale_entry.clone() {
754 #[cfg(feature = "metrics")]
755 counter!("tower_http_cache.stale_served").increment(1);
756 return Ok(entry.into_response());
757 }
758
759 let secondary_guard = lock.lock_owned().await;
760 drop(secondary_guard);
761
762 if let Ok(Some(hit)) = backend.get(key_ref).await {
763 match classify_hit(hit, stale_window, refresh_before) {
764 HitState::Fresh(entry) => {
765 #[cfg(feature = "metrics")]
766 counter!("tower_http_cache.hit_after_wait").increment(1);
767 return Ok(entry.into_response());
768 }
769 HitState::Stale(entry) => {
770 #[cfg(feature = "metrics")]
771 counter!("tower_http_cache.stale_served").increment(1);
772 return Ok(entry.into_response());
773 }
774 HitState::Expired => {}
775 }
776 }
777
778 if let StampedeHandle::Primary(guard) =
779 StampedeGuard::acquire_handle(locks.clone(), key_ref.clone()).await
780 {
781 primary_guard = Some(guard);
782 }
783 }
784 }
785 }
786
787 #[cfg(feature = "metrics")]
788 counter!("tower_http_cache.miss").increment(1);
789
790 #[cfg(feature = "metrics")]
791 let start = std::time::Instant::now();
792 let service = inner;
793 let response = service.oneshot(req).await.map_err(|err| err.into())?;
794 #[cfg(feature = "metrics")]
795 histogram!("tower_http_cache.backend_latency").record(start.elapsed().as_secs_f64());
796
797 let (parts, body) = response.into_parts();
798
799 let content_type = parts
801 .headers
802 .get(http::header::CONTENT_TYPE)
803 .and_then(|v| v.to_str().ok());
804
805 let content_length = parts
806 .headers
807 .get(http::header::CONTENT_LENGTH)
808 .and_then(|v| v.to_str().ok())
809 .and_then(|v| v.parse::<u64>().ok());
810
811 let size_hint = body.size_hint();
812
813 let streaming_decision = should_stream(
814 policy.streaming_policy(),
815 &size_hint,
816 content_type,
817 content_length,
818 );
819
820 let is_range_request = parse_range_header(&parts.headers).is_some();
822 let is_partial_response = is_partial_content(parts.status);
823 let range_handling = policy.streaming_policy().range_handling;
824
825 if (is_range_request || is_partial_response)
827 && range_handling == RangeHandling::PassThrough
828 {
829 #[cfg(feature = "metrics")]
830 counter!("tower_http_cache.range_request_passthrough").increment(1);
831
832 #[cfg(feature = "tracing")]
833 tracing::debug!(
834 method = %method,
835 uri = %uri,
836 is_range = is_range_request,
837 is_partial = is_partial_response,
838 "range_request_passthrough"
839 );
840
841 let boxed_body = SyncBoxBody::new(body.map_err(Into::into).boxed());
843 drop(primary_guard);
844 return Ok(Response::from_parts(parts, boxed_body));
845 }
846
847 match streaming_decision {
849 StreamingDecision::SkipCache | StreamingDecision::StreamThrough => {
850 #[cfg(feature = "metrics")]
852 counter!("tower_http_cache.streaming_passthrough").increment(1);
853
854 #[cfg(feature = "tracing")]
855 tracing::debug!(
856 method = %method,
857 uri = %uri,
858 decision = ?streaming_decision,
859 content_type = ?content_type,
860 size = ?extract_size_info(&size_hint, content_length),
861 "streaming_passthrough"
862 );
863
864 let boxed_body = SyncBoxBody::new(body.map_err(Into::into).boxed());
866 drop(primary_guard);
867 return Ok(Response::from_parts(parts, boxed_body));
868 }
869 _ => {}
870 }
871
872 let streaming = body.size_hint().upper().is_none();
873 if streaming && !policy.allow_streaming_bodies() {
874 #[cfg(feature = "metrics")]
875 counter!("tower_http_cache.streaming_skip").increment(1);
876 }
877
878 let collected = BodyExt::collect(body).await.map_err(|err| err.into())?;
879 let cache_bytes = collected.to_bytes();
880 let response_bytes = cache_bytes.clone();
881
882 if let (Some(ref chunk_cache), Some(ref key_ref)) = (&chunk_cache, &key) {
884 let streaming_policy = policy.streaming_policy();
885
886 if streaming_policy.enable_chunk_cache
887 && cache_bytes.len() as u64 >= streaming_policy.min_chunk_file_size
888 && parts.status.is_success()
889 {
890 #[cfg(feature = "tracing")]
891 tracing::debug!(
892 key = %key_ref,
893 size = cache_bytes.len(),
894 chunk_size = streaming_policy.chunk_size,
895 "populating_chunk_cache"
896 );
897
898 let metadata = ChunkMetadata {
900 total_size: cache_bytes.len() as u64,
901 content_type: parts
902 .headers
903 .get(http::header::CONTENT_TYPE)
904 .and_then(|v| v.to_str().ok())
905 .unwrap_or("application/octet-stream")
906 .to_string(),
907 etag: parts
908 .headers
909 .get(http::header::ETAG)
910 .and_then(|v| v.to_str().ok())
911 .map(|s| s.to_string()),
912 last_modified: parts
913 .headers
914 .get(http::header::LAST_MODIFIED)
915 .and_then(|v| v.to_str().ok())
916 .map(|s| s.to_string()),
917 status: parts.status,
918 version: parts.version,
919 headers: policy.headers_to_cache(&parts.headers),
920 };
921
922 let entry = chunk_cache.get_or_create(key_ref.clone(), metadata);
924
925 let chunk_size = streaming_policy.chunk_size;
927 let mut offset = 0;
928 let mut chunk_index = 0;
929
930 while offset < cache_bytes.len() {
931 let end = std::cmp::min(offset + chunk_size, cache_bytes.len());
932 let chunk = cache_bytes.slice(offset..end);
933 entry.add_chunk(chunk_index, chunk);
934
935 offset = end;
936 chunk_index += 1;
937 }
938
939 #[cfg(feature = "metrics")]
940 counter!("tower_http_cache.chunk_cache_stored").increment(1);
941
942 #[cfg(feature = "tracing")]
943 tracing::debug!(
944 key = %key_ref,
945 chunks = chunk_index,
946 "chunk_cache_populated"
947 );
948 }
949 }
950
951 let cache_control_block =
952 policy.respect_cache_control() && cache_control_disallows(&parts.headers);
953 let body_too_large = policy
954 .max_body_size()
955 .is_some_and(|max| cache_bytes.len() > max);
956 let body_too_small = policy
957 .min_body_size()
958 .is_some_and(|min| cache_bytes.len() < min);
959
960 let should_store = key.is_some()
961 && !cache_control_block
962 && !body_too_large
963 && !body_too_small
964 && (policy.allow_streaming_bodies() || !streaming);
965
966 let headers_to_cache = if should_store {
967 Some(policy.headers_to_cache(&parts.headers))
968 } else {
969 None
970 };
971
972 let version = parts.version;
973 let status = parts.status;
974
975 if should_store {
976 if let Some(key_ref) = &key {
977 if let Some(ttl) = policy.ttl_for(status) {
978 if !ttl.is_zero() {
979 let stale_for = policy.stale_while_revalidate();
980 let (compressed_bytes, compressed) =
981 maybe_compress(cache_bytes.clone(), policy.compression());
982 if compressed {
983 #[cfg(feature = "metrics")]
984 counter!("tower_http_cache.compressed").increment(1);
985 }
986 let entry = CacheEntry::new(
987 status,
988 version,
989 headers_to_cache.unwrap(),
990 compressed_bytes,
991 );
992 if backend
993 .set(key_ref.clone(), entry, ttl, stale_for)
994 .await
995 .is_err()
996 {
997 #[cfg(feature = "metrics")]
998 counter!("tower_http_cache.store_error").increment(1);
999 } else {
1000 #[cfg(feature = "metrics")]
1001 counter!("tower_http_cache.store").increment(1);
1002
1003 if let (Some(ref manager), Some(metadata)) =
1005 (&refresh_manager, refresh_metadata)
1006 {
1007 manager.store_metadata(key_ref.clone(), metadata);
1008 }
1009 }
1010 }
1011 }
1012 }
1013 } else {
1014 #[cfg(feature = "metrics")]
1015 counter!("tower_http_cache.store_skipped").increment(1);
1016 }
1017
1018 drop(primary_guard);
1019
1020 let full_body = Full::from(response_bytes);
1022 let boxed_body = SyncBoxBody::new(full_body.map_err(Into::into).boxed());
1023 Ok(Response::from_parts(parts, boxed_body))
1024 })
1025 }
1026}
1027
/// Held by the single "primary" task allowed to fetch a given key from the origin.
///
/// While alive it holds the key's mutex, so tasks that find the key already present in
/// `locks` can wait on it. Dropping it removes the key from `locks` (if the registered
/// lock is still this guard's) before the mutex guard field is released.
struct StampedeGuard {
    /// Cache key this guard protects.
    key: String,
    /// Shared lock table the guard deregisters itself from on drop.
    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
    /// The mutex registered for `key`; compared by pointer on drop.
    lock: Arc<Mutex<()>>,
    /// Keeps `lock` held for the guard's lifetime.
    _guard: OwnedMutexGuard<()>,
}
1034
/// Result of [`StampedeGuard::acquire_handle`].
enum StampedeHandle {
    /// No lock was registered for the key; the caller now holds it and should fetch.
    Primary(StampedeGuard),
    /// Another task registered the key's lock; the caller can wait on this mutex.
    Secondary(Arc<Mutex<()>>),
}
1039
1040impl StampedeGuard {
1041 async fn acquire_handle(
1042 locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
1043 key: String,
1044 ) -> StampedeHandle {
1045 let handle = match locks.entry(key.clone()) {
1046 Entry::Occupied(entry) => StampedeHandle::Secondary(entry.get().clone()),
1047 Entry::Vacant(entry) => {
1048 let lock = Arc::new(Mutex::new(()));
1049 entry.insert(lock.clone());
1050 let guard = lock.clone().lock_owned().await;
1051 let locks_clone = locks.clone();
1052 StampedeHandle::Primary(StampedeGuard {
1053 key,
1054 locks: locks_clone,
1055 lock,
1056 _guard: guard,
1057 })
1058 }
1059 };
1060 handle
1061 }
1062}
1063
1064impl Drop for StampedeGuard {
1065 fn drop(&mut self) {
1066 if let Some(current) = self.locks.get(&self.key) {
1067 let should_remove = Arc::ptr_eq(&self.lock, current.value());
1068 drop(current);
1069 if should_remove {
1070 self.locks.remove(&self.key);
1071 }
1072 }
1073 }
1074}
1075
1076fn classify_hit(hit: CacheRead, stale_window: Duration, refresh_before: Duration) -> HitState {
1077 let now = SystemTime::now();
1078 let CacheRead {
1079 entry,
1080 expires_at,
1081 stale_until,
1082 } = hit;
1083
1084 if let Some(expires_at) = expires_at {
1085 if expires_at > now {
1086 if refresh_before > Duration::ZERO {
1087 if let Some(threshold) = expires_at.checked_sub(refresh_before) {
1088 if now >= threshold {
1089 return HitState::Stale(entry);
1090 }
1091 } else {
1092 return HitState::Stale(entry);
1093 }
1094 }
1095 return HitState::Fresh(entry);
1096 }
1097 }
1098
1099 if stale_window > Duration::ZERO {
1100 if let Some(stale_until) = stale_until {
1101 if stale_until > now {
1102 return HitState::Stale(entry);
1103 }
1104 }
1105 }
1106
1107 HitState::Expired
1108}
1109
/// Freshness classification of a backend read, produced by [`classify_hit`].
#[derive(Debug)]
enum HitState {
    /// Before expiry and outside the refresh-before window; returned to the client as-is.
    Fresh(CacheEntry),
    /// Inside the refresh-before window, or past expiry but within the stale window.
    Stale(CacheEntry),
    /// Past expiry and any stale window; treated as a miss.
    Expired,
}
1116
1117fn cache_control_disallows(headers: &HeaderMap) -> bool {
1118 headers
1119 .get_all(CACHE_CONTROL)
1120 .iter()
1121 .filter_map(|value| value.to_str().ok())
1122 .flat_map(|value| value.split(','))
1123 .map(|token| token.trim().to_ascii_lowercase())
1124 .any(|token| matches!(token.as_str(), "no-store" | "no-cache" | "private"))
1125 || headers
1126 .get(PRAGMA)
1127 .and_then(|value| value.to_str().ok())
1128 .map(|value| value.to_ascii_lowercase().contains("no-cache"))
1129 .unwrap_or(false)
1130}
1131
#[cfg(feature = "compression")]
/// Gzip-compresses `bytes` when the strategy is `Gzip` and the payload is at least
/// `config.min_size` bytes long.
///
/// Returns the (possibly compressed) payload and whether compression was applied. Any
/// encoder error falls back to the original bytes.
fn maybe_compress(bytes: Bytes, config: CompressionConfig) -> (Bytes, bool) {
    use flate2::{write::GzEncoder, Compression};
    use std::io::Write;

    let gzip_eligible = match config.strategy {
        CompressionStrategy::None => false,
        CompressionStrategy::Gzip => bytes.len() >= config.min_size,
    };
    if !gzip_eligible {
        return (bytes, false);
    }

    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    let encoded = encoder.write_all(&bytes).and_then(|()| encoder.finish());
    match encoded {
        Ok(compressed) => (Bytes::from(compressed), true),
        Err(_) => (bytes, false),
    }
}
1154
1155#[cfg(not(feature = "compression"))]
1156fn maybe_compress(bytes: Bytes, _config: CompressionConfig) -> (Bytes, bool) {
1157 let _ = _config;
1158 (bytes, false)
1159}
1160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::CacheEntry;
    use bytes::Bytes;
    use http::{HeaderValue, StatusCode, Version};
    use tokio::task::yield_now;

    /// A minimal `200 OK` entry with body `"body"`.
    fn mock_entry() -> CacheEntry {
        CacheEntry::new(
            StatusCode::OK,
            Version::HTTP_11,
            Vec::new(),
            Bytes::from_static(b"body"),
        )
    }

    /// A backend read with the given expiry and stale deadline.
    fn read_with(expires_at: SystemTime, stale_until: SystemTime) -> CacheRead {
        CacheRead {
            entry: mock_entry(),
            expires_at: Some(expires_at),
            stale_until: Some(stale_until),
        }
    }

    #[test]
    fn classify_hit_marks_entry_fresh_when_not_near_expiry() {
        let now = SystemTime::now();
        let read = read_with(now + Duration::from_secs(10), now + Duration::from_secs(20));
        let state = classify_hit(read, Duration::from_secs(5), Duration::from_secs(1));
        assert!(
            matches!(state, HitState::Fresh(_)),
            "expected fresh entry, got {:?}",
            state
        );
    }

    #[test]
    fn classify_hit_marks_entry_stale_when_within_refresh_window() {
        let now = SystemTime::now();
        let read = read_with(now + Duration::from_secs(2), now + Duration::from_secs(10));
        let state = classify_hit(read, Duration::from_secs(5), Duration::from_secs(5));
        assert!(
            matches!(state, HitState::Stale(_)),
            "expected stale entry, got {:?}",
            state
        );
    }

    #[test]
    fn classify_hit_marks_entry_stale_when_within_stale_window() {
        let now = SystemTime::now();
        let read = read_with(now - Duration::from_secs(1), now + Duration::from_secs(1));
        let state = classify_hit(read, Duration::from_secs(2), Duration::from_secs(0));
        assert!(
            matches!(state, HitState::Stale(_)),
            "expected stale entry, got {:?}",
            state
        );
    }

    #[test]
    fn classify_hit_marks_entry_expired_after_stale_window() {
        let now = SystemTime::now();
        let read = read_with(now - Duration::from_secs(5), now - Duration::from_secs(1));
        let state = classify_hit(read, Duration::from_secs(5), Duration::from_secs(0));
        assert!(
            matches!(state, HitState::Expired),
            "expected expired entry, got {:?}",
            state
        );
    }

    #[test]
    fn cache_control_disallows_detects_no_cache_directives() {
        let mut cache_control = HeaderMap::new();
        cache_control.insert(
            CACHE_CONTROL,
            HeaderValue::from_static("max-age=0, no-cache"),
        );
        assert!(cache_control_disallows(&cache_control));

        let mut pragma_only = HeaderMap::new();
        pragma_only.insert(PRAGMA, HeaderValue::from_static("no-cache"));
        assert!(cache_control_disallows(&pragma_only));
    }

    #[tokio::test]
    async fn stampede_guard_drop_removes_lock_entry() {
        let locks = Arc::new(DashMap::new());
        let key = String::from("key");

        let StampedeHandle::Primary(guard) =
            StampedeGuard::acquire_handle(Arc::clone(&locks), key.clone()).await
        else {
            panic!("expected primary guard");
        };

        assert!(locks.get(&key).is_some());
        drop(guard);
        yield_now().await;
        assert!(locks.get(&key).is_none());
    }

    #[test]
    fn cache_service_implements_clone() {
        use crate::backend::memory::InMemoryBackend;
        use tower::service_fn;

        fn assert_clone<T: Clone>(_: &T) {}

        let layer = CacheLayer::new(InMemoryBackend::new(100));
        let service = layer.layer(service_fn(|_req: http::Request<()>| async {
            Ok::<_, std::convert::Infallible>(http::Response::new(()))
        }));

        assert_clone(&service);
    }
}