tower_http_cache/
layer.rs

1use std::error::Error as StdError;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::{Duration, SystemTime};
5
6use bytes::Bytes;
7use dashmap::mapref::entry::Entry;
8use dashmap::DashMap;
9use futures_util::future::BoxFuture;
10use http::header::{CACHE_CONTROL, PRAGMA};
11use http::{HeaderMap, Method, Request, Response, Uri};
12use http_body::Body;
13use http_body_util::combinators::BoxBody;
14use http_body_util::{BodyExt, Full};
15use tokio::sync::{Mutex, OwnedMutexGuard};
16use tower::{Layer, Service, ServiceExt};
17
18#[cfg(feature = "metrics")]
19use metrics::{counter, histogram};
20
21use crate::backend::memory::InMemoryBackend;
22use crate::backend::{CacheBackend, CacheEntry, CacheRead};
23use crate::chunks::{ChunkCache, ChunkMetadata};
24#[cfg(feature = "compression")]
25use crate::policy::CompressionStrategy;
26use crate::policy::{CachePolicy, CompressionConfig};
27use crate::range::{is_partial_content, parse_range_header, RangeHandling};
28use crate::refresh::{AutoRefreshConfig, RefreshCallback, RefreshManager, RefreshMetadata};
29#[cfg(feature = "tracing")]
30use crate::streaming::extract_size_info;
31use crate::streaming::{should_stream, StreamingDecision};
32
/// Boxed, thread-safe error type used as the body error and service error in this module.
pub type BoxError = Box<dyn StdError + Send + Sync>;
34
pin_project_lite::pin_project! {
    /// Boxed response body returned by the cache service, intended to be usable
    /// where a `Sync` body is required (e.g. Axum).
    ///
    /// This is a thin newtype around `BoxBody<Bytes, BoxError>`: its [`Body`]
    /// implementation forwards `poll_frame`, `is_end_stream` and `size_hint`
    /// to the wrapped body. The `BoxBody` is stored directly; no additional
    /// wrapper type sits between this struct and the inner body.
    pub struct SyncBoxBody {
        // Structurally pinned so `poll_frame` can project to the inner body.
        #[pin]
        inner: BoxBody<Bytes, BoxError>,
    }
}
46
47impl SyncBoxBody {
48    /// Creates a new SyncBoxBody by wrapping a BoxBody.
49    pub fn new(inner: BoxBody<Bytes, BoxError>) -> Self {
50        Self { inner }
51    }
52}
53
54// SAFETY: BoxBody is Send, and we're using this in a single-threaded Tower service context.
55// This is the same pattern used by Axum's Body type.
56unsafe impl Sync for SyncBoxBody {}
57
58impl Body for SyncBoxBody {
59    type Data = Bytes;
60    type Error = BoxError;
61
62    fn poll_frame(
63        self: std::pin::Pin<&mut Self>,
64        cx: &mut std::task::Context<'_>,
65    ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
66        self.project().inner.poll_frame(cx)
67    }
68
69    fn is_end_stream(&self) -> bool {
70        self.inner.is_end_stream()
71    }
72
73    fn size_hint(&self) -> http_body::SizeHint {
74        self.inner.size_hint()
75    }
76}
77
/// Shared, thread-safe closure that maps a request's method and URI to an
/// optional cache key; stored inside [`KeyExtractor`].
type KeyExtractorFn = Arc<dyn Fn(&Method, &Uri) -> Option<String> + Send + Sync>;
80
/// Configurable caching layer for Tower services.
///
/// The layer wraps an inner service and caches HTTP responses based on the
/// configured [`CachePolicy`]. Create instances via [`CacheLayer::builder`]
/// or [`CacheLayer::new`] for a sensible default policy.
///
/// Cloning a `CacheLayer` is cheap and shares the underlying backend and
/// in-flight stampede locks.
#[derive(Clone)]
pub struct CacheLayer<B> {
    // Cache storage; cloned into every service produced by `layer`.
    backend: B,
    // Caching rules; cloned into every service produced by `layer`.
    policy: CachePolicy,
    // Turns a request's method and URI into an optional cache key.
    key_extractor: KeyExtractor,
    // Per-key async mutexes, shared (via `Arc`) with all clones and services.
    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
    // Present only when the builder was given an enabled `AutoRefreshConfig`.
    refresh_manager: Option<Arc<RefreshManager>>,
    // Present only when the builder's streaming policy had `enable_chunk_cache` set.
    chunk_cache: Option<Arc<ChunkCache>>,
}
98
/// Strategy used to turn requests into cache keys.
///
/// The layer ships with helpers for common patterns such as
/// [`KeyExtractor::path_and_query`] and [`KeyExtractor::path`].
/// You can also provide your own extractor with [`KeyExtractor::custom`].
///
/// Cloning is cheap: the extractor function is held behind an `Arc`.
#[derive(Clone)]
pub struct KeyExtractor {
    // The extraction closure; returns `None` for requests that yield no key.
    inner: KeyExtractorFn,
}
108
109impl KeyExtractor {
110    /// Builds an extractor that uses `method + path + query` for GET/HEAD requests.
111    pub fn path_and_query() -> Self {
112        Self {
113            inner: Arc::new(|method: &Method, uri: &Uri| {
114                if matches!(method, &Method::GET | &Method::HEAD) {
115                    let mut key = uri.path().to_owned();
116                    if let Some(query) = uri.query() {
117                        key.push('?');
118                        key.push_str(query);
119                    }
120                    Some(key)
121                } else {
122                    None
123                }
124            }),
125        }
126    }
127
128    pub fn path() -> Self {
129        Self {
130            inner: Arc::new(|method: &Method, uri: &Uri| {
131                if matches!(method, &Method::GET | &Method::HEAD) {
132                    Some(uri.path().to_owned())
133                } else {
134                    None
135                }
136            }),
137        }
138    }
139
140    pub fn custom<F>(func: F) -> Self
141    where
142        F: Fn(&Method, &Uri) -> Option<String> + Send + Sync + 'static,
143    {
144        Self {
145            inner: Arc::new(func),
146        }
147    }
148
149    /// Extracts a cache key from the provided request parts.
150    ///
151    /// Returns `None` when the request should be skipped.
152    pub fn extract(&self, method: &Method, uri: &Uri) -> Option<String> {
153        (self.inner)(method, uri)
154    }
155}
156
157impl Default for KeyExtractor {
158    fn default() -> Self {
159        Self::path_and_query()
160    }
161}
162
/// Builder for configuring [`CacheLayer`] instances.
pub struct CacheLayerBuilder<B> {
    // Cache storage handed to the built layer.
    backend: B,
    // Policy being assembled; starts as `CachePolicy::default()`.
    policy: CachePolicy,
    // Key strategy; starts as `KeyExtractor::default()`.
    key_extractor: KeyExtractor,
    // Auto-refresh settings; `None` until `auto_refresh` is called.
    auto_refresh_config: Option<AutoRefreshConfig>,
}
170
171impl<B> CacheLayerBuilder<B>
172where
173    B: CacheBackend,
174{
175    pub fn new(backend: B) -> Self {
176        Self {
177            backend,
178            policy: CachePolicy::default(),
179            key_extractor: KeyExtractor::default(),
180            auto_refresh_config: None,
181        }
182    }
183
184    /// Replaces the cache policy with a pre-built value.
185    pub fn policy(mut self, policy: CachePolicy) -> Self {
186        self.policy = policy;
187        self
188    }
189
190    /// Sets the positive cache TTL for successful responses.
191    pub fn ttl(mut self, ttl: Duration) -> Self {
192        self.policy = self.policy.with_ttl(ttl);
193        self
194    }
195
196    /// Sets the cache TTL for negative (4xx) responses.
197    pub fn negative_ttl(mut self, ttl: Duration) -> Self {
198        self.policy = self.policy.with_negative_ttl(ttl);
199        self
200    }
201
202    pub fn stale_while_revalidate(mut self, duration: Duration) -> Self {
203        self.policy = self.policy.with_stale_while_revalidate(duration);
204        self
205    }
206
207    pub fn refresh_before(mut self, duration: Duration) -> Self {
208        self.policy = self.policy.with_refresh_before(duration);
209        self
210    }
211
212    pub fn max_body_size(mut self, size: Option<usize>) -> Self {
213        self.policy = self.policy.with_max_body_size(size);
214        self
215    }
216
217    pub fn min_body_size(mut self, size: Option<usize>) -> Self {
218        self.policy = self.policy.with_min_body_size(size);
219        self
220    }
221
222    pub fn allow_streaming_bodies(mut self, allow: bool) -> Self {
223        self.policy = self.policy.with_allow_streaming_bodies(allow);
224        self
225    }
226
227    pub fn compression(mut self, config: CompressionConfig) -> Self {
228        self.policy = self.policy.with_compression(config);
229        self
230    }
231
232    pub fn respect_cache_control(mut self, enabled: bool) -> Self {
233        self.policy = self.policy.with_respect_cache_control(enabled);
234        self
235    }
236
237    pub fn statuses(mut self, statuses: impl IntoIterator<Item = u16>) -> Self {
238        self.policy = self.policy.with_statuses(statuses);
239        self
240    }
241
242    pub fn method_predicate<F>(mut self, predicate: F) -> Self
243    where
244        F: Fn(&Method) -> bool + Send + Sync + 'static,
245    {
246        self.policy = self.policy.with_method_predicate(predicate);
247        self
248    }
249
250    pub fn header_allowlist<I, S>(mut self, headers: I) -> Self
251    where
252        I: IntoIterator<Item = S>,
253        S: Into<String>,
254    {
255        self.policy = self.policy.with_header_allowlist(headers);
256        self
257    }
258
259    pub fn key_extractor(mut self, extractor: KeyExtractor) -> Self {
260        self.key_extractor = extractor;
261        self
262    }
263
264    /// Enables auto-refresh functionality with the provided configuration.
265    ///
266    /// When enabled, frequently accessed cache entries will be proactively
267    /// refreshed before they expire, reducing cache misses and latency.
268    pub fn auto_refresh(mut self, config: AutoRefreshConfig) -> Self {
269        self.auto_refresh_config = Some(config);
270        self
271    }
272
273    pub fn build(self) -> CacheLayer<B> {
274        let refresh_manager = self
275            .auto_refresh_config
276            .filter(|cfg| cfg.enabled)
277            .map(|cfg| Arc::new(RefreshManager::new(cfg)));
278
279        // Create chunk cache if enabled in streaming policy
280        let chunk_cache = if self.policy.streaming_policy().enable_chunk_cache {
281            Some(Arc::new(ChunkCache::new(
282                self.policy.streaming_policy().chunk_size,
283            )))
284        } else {
285            None
286        };
287
288        CacheLayer {
289            backend: self.backend,
290            policy: self.policy,
291            key_extractor: self.key_extractor,
292            locks: Arc::new(DashMap::new()),
293            refresh_manager,
294            chunk_cache,
295        }
296    }
297}
298
299impl CacheLayer<InMemoryBackend> {
300    /// Creates a cache layer backed by an in-memory [`InMemoryBackend`].
301    pub fn new_in_memory(max_capacity: u64) -> Self {
302        CacheLayerBuilder::new(InMemoryBackend::new(max_capacity)).build()
303    }
304}
305
306impl<B> CacheLayer<B>
307where
308    B: CacheBackend,
309{
310    /// Builds a cache layer with the default [`CachePolicy`].
311    pub fn new(backend: B) -> Self {
312        CacheLayerBuilder::new(backend).build()
313    }
314
315    /// Returns a builder for fine-grained control over the cache policy.
316    pub fn builder(backend: B) -> CacheLayerBuilder<B> {
317        CacheLayerBuilder::new(backend)
318    }
319
320    pub fn with_policy(mut self, policy: CachePolicy) -> Self {
321        self.policy = policy;
322        self
323    }
324
325    pub fn with_ttl(mut self, ttl: Duration) -> Self {
326        self.policy = self.policy.clone().with_ttl(ttl);
327        self
328    }
329
330    pub fn with_negative_ttl(mut self, ttl: Duration) -> Self {
331        self.policy = self.policy.clone().with_negative_ttl(ttl);
332        self
333    }
334
335    pub fn with_stale_while_revalidate(mut self, duration: Duration) -> Self {
336        self.policy = self.policy.clone().with_stale_while_revalidate(duration);
337        self
338    }
339
340    pub fn with_refresh_before(mut self, duration: Duration) -> Self {
341        self.policy = self.policy.clone().with_refresh_before(duration);
342        self
343    }
344
345    pub fn with_max_body_size(mut self, size: Option<usize>) -> Self {
346        self.policy = self.policy.clone().with_max_body_size(size);
347        self
348    }
349
350    pub fn with_min_body_size(mut self, size: Option<usize>) -> Self {
351        self.policy = self.policy.clone().with_min_body_size(size);
352        self
353    }
354
355    pub fn with_allow_streaming_bodies(mut self, allow: bool) -> Self {
356        self.policy = self.policy.clone().with_allow_streaming_bodies(allow);
357        self
358    }
359
360    pub fn with_compression(mut self, config: CompressionConfig) -> Self {
361        self.policy = self.policy.clone().with_compression(config);
362        self
363    }
364
365    pub fn with_respect_cache_control(mut self, enabled: bool) -> Self {
366        self.policy = self.policy.clone().with_respect_cache_control(enabled);
367        self
368    }
369
370    pub fn with_cache_statuses(mut self, statuses: impl IntoIterator<Item = u16>) -> Self {
371        self.policy = self.policy.clone().with_statuses(statuses);
372        self
373    }
374
375    pub fn with_method_predicate<F>(mut self, predicate: F) -> Self
376    where
377        F: Fn(&Method) -> bool + Send + Sync + 'static,
378    {
379        self.policy = self.policy.clone().with_method_predicate(predicate);
380        self
381    }
382
383    pub fn with_header_allowlist<I, S>(mut self, headers: I) -> Self
384    where
385        I: IntoIterator<Item = S>,
386        S: Into<String>,
387    {
388        self.policy = self.policy.clone().with_header_allowlist(headers);
389        self
390    }
391
392    pub fn with_key_extractor(mut self, extractor: KeyExtractor) -> Self {
393        self.key_extractor = extractor;
394        self
395    }
396
397    /// Manually initialize the auto-refresh manager with a service instance.
398    ///
399    /// This should be called after constructing the service to start the background
400    /// refresh task. This is only necessary if auto-refresh is enabled.
401    ///
402    /// # Example
403    ///
404    /// ```ignore
405    /// let layer = CacheLayer::builder(backend)
406    ///     .auto_refresh(config)
407    ///     .build();
408    ///
409    /// layer.init_auto_refresh(my_service.clone()).await?;
410    /// ```
411    pub async fn init_auto_refresh<S, ResBody>(&self, service: S) -> Result<(), String>
412    where
413        S: Service<Request<()>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
414        S::Future: Send + 'static,
415        S::Error: Into<BoxError> + Send,
416        ResBody: Body<Data = Bytes> + Send + 'static,
417        ResBody::Error: Into<BoxError> + Send,
418        B: Clone,
419    {
420        if let Some(ref manager) = self.refresh_manager {
421            let callback = Arc::new(CacheRefreshCallback::new(
422                service,
423                self.backend.clone(),
424                self.policy.clone(),
425                self.key_extractor.clone(),
426            ));
427            manager.start(callback).await
428        } else {
429            Ok(())
430        }
431    }
432}
433
434impl<S, B> Layer<S> for CacheLayer<B>
435where
436    B: CacheBackend,
437{
438    type Service = CacheService<S, B>;
439
440    fn layer(&self, inner: S) -> Self::Service {
441        CacheService {
442            inner,
443            backend: self.backend.clone(),
444            policy: self.policy.clone(),
445            key_extractor: self.key_extractor.clone(),
446            locks: self.locks.clone(),
447            refresh_manager: self.refresh_manager.clone(),
448            chunk_cache: self.chunk_cache.clone(),
449        }
450    }
451}
452
453impl<B> Drop for CacheLayer<B> {
454    fn drop(&mut self) {
455        // Trigger graceful shutdown of refresh manager
456        // We use tokio::spawn to avoid blocking in Drop
457        if let Some(manager) = &self.refresh_manager {
458            let manager = manager.clone();
459            // Best-effort shutdown - spawn detached task
460            // Note: We cannot guarantee execution in Drop, but we try our best
461            if let Ok(handle) = tokio::runtime::Handle::try_current() {
462                handle.spawn(async move {
463                    manager.shutdown().await;
464                });
465            }
466        }
467    }
468}
469
/// Tower service created by [`CacheLayer`]'s `layer` method; it wraps an inner
/// service together with the cache state handed over by the layer.
#[derive(Clone)]
pub struct CacheService<S, B> {
    // The wrapped service that produces responses on a cache miss.
    inner: S,
    // Cache storage, cloned from the layer.
    backend: B,
    // Caching rules, cloned from the layer.
    policy: CachePolicy,
    // Derives the cache key from a request's method and URI.
    key_extractor: KeyExtractor,
    // Per-key async mutexes shared with the layer and its other services.
    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
    // Shared auto-refresh manager; `None` when the layer has none.
    refresh_manager: Option<Arc<RefreshManager>>,
    // Shared chunk cache; `None` when the layer has none.
    chunk_cache: Option<Arc<ChunkCache>>,
}
480
/// Implementation of RefreshCallback for CacheService.
///
/// Holds what a refresh needs: a service to call, the backend to store the
/// result in, and the policy that decides whether and how long to store it.
struct CacheRefreshCallback<S, B> {
    // Service invoked with the reconstructed request.
    inner: S,
    // Backend that receives the refreshed entry.
    backend: B,
    // Policy consulted for body-size limits, TTL, headers and compression.
    policy: CachePolicy,
}
487
488impl<S, B> CacheRefreshCallback<S, B> {
489    fn new(inner: S, backend: B, policy: CachePolicy, _key_extractor: KeyExtractor) -> Self {
490        Self {
491            inner,
492            backend,
493            policy,
494        }
495    }
496}
497
498impl<S, B, ResBody> RefreshCallback for CacheRefreshCallback<S, B>
499where
500    S: Service<Request<()>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
501    S::Future: Send + 'static,
502    S::Error: Into<BoxError> + Send,
503    ResBody: Body<Data = Bytes> + Send + 'static,
504    ResBody::Error: Into<BoxError> + Send,
505    B: CacheBackend,
506{
507    fn refresh(&self, key: String, metadata: RefreshMetadata) -> crate::refresh::RefreshFuture {
508        let backend = self.backend.clone();
509        let policy = self.policy.clone();
510        let inner = self.inner.clone();
511
512        Box::pin(async move {
513            #[cfg(feature = "tracing")]
514            tracing::debug!(key = %key, uri = %metadata.uri, "Auto-refresh triggered");
515
516            // Reconstruct the request
517            let request = match metadata.try_into_request() {
518                Some(req) => req,
519                None => {
520                    #[cfg(feature = "tracing")]
521                    tracing::warn!(key = %key, "Failed to reconstruct request for auto-refresh");
522                    return Err("Failed to reconstruct request".into());
523                }
524            };
525
526            // Call the inner service
527            let service = inner;
528            let response = match service.oneshot(request).await {
529                Ok(resp) => resp,
530                Err(_err) => {
531                    #[cfg(feature = "tracing")]
532                    tracing::error!(key = %key, "Service error during auto-refresh");
533                    return Err("Service error during refresh".into());
534                }
535            };
536
537            let (parts, body) = response.into_parts();
538
539            // Collect the body
540            let collected = match BodyExt::collect(body).await {
541                Ok(c) => c,
542                Err(_err) => {
543                    #[cfg(feature = "tracing")]
544                    tracing::error!(key = %key, "Body collection error during auto-refresh");
545                    return Err("Body collection error".into());
546                }
547            };
548
549            let cache_bytes = collected.to_bytes();
550
551            // Check if we should cache this response
552            let body_too_large = policy
553                .max_body_size()
554                .is_some_and(|max| cache_bytes.len() > max);
555            let body_too_small = policy
556                .min_body_size()
557                .is_some_and(|min| cache_bytes.len() < min);
558
559            if body_too_large || body_too_small {
560                return Ok(()); // Successfully refreshed but not stored
561            }
562
563            // Store the refreshed entry
564            if let Some(ttl) = policy.ttl_for(parts.status) {
565                if !ttl.is_zero() {
566                    let stale_for = policy.stale_while_revalidate();
567                    let headers_to_cache = policy.headers_to_cache(&parts.headers);
568                    let (compressed_bytes, _compressed) =
569                        maybe_compress(cache_bytes, policy.compression());
570
571                    let entry = CacheEntry::new(
572                        parts.status,
573                        parts.version,
574                        headers_to_cache,
575                        compressed_bytes,
576                    );
577
578                    if let Err(_err) = backend.set(key.clone(), entry, ttl, stale_for).await {
579                        #[cfg(feature = "tracing")]
580                        tracing::error!(key = %key, "Failed to store refreshed entry");
581                        return Err("Failed to store entry".into());
582                    }
583                }
584            }
585
586            Ok(())
587        })
588    }
589}
590
// Note: The refresh manager cannot easily be initialized from within the service,
// because the service may have different type parameters than the refresh callback requires.
// Users who want auto-refresh should therefore ensure the service is called at least once,
// or initialize the refresh functionality manually via `CacheLayer::init_auto_refresh`.
595
596impl<S, B, ReqBody, ResBody> Service<Request<ReqBody>> for CacheService<S, B>
597where
598    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
599    S::Future: Send + 'static,
600    S::Error: Into<BoxError> + Send,
601    ReqBody: Send + 'static,
602    ResBody: Body<Data = Bytes> + Send + Sync + 'static,
603    ResBody::Error: Into<BoxError> + Send,
604    B: CacheBackend,
605{
606    type Response = Response<SyncBoxBody>;
607    type Error = BoxError;
608    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
609
610    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
611        self.inner.poll_ready(cx).map_err(Into::into)
612    }
613
614    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
615        let method = req.method().clone();
616        let uri = req.uri().clone();
617        let should_cache_method = self.policy.should_cache_method(&method);
618        let request_bypass =
619            self.policy.respect_cache_control() && cache_control_disallows(req.headers());
620        let key = if should_cache_method && !request_bypass {
621            self.key_extractor.extract(&method, &uri)
622        } else {
623            None
624        };
625
626        let backend = self.backend.clone();
627        let policy = self.policy.clone();
628        let locks = self.locks.clone();
629        let inner = self.inner.clone();
630        let stale_window = policy.stale_while_revalidate();
631        let refresh_before = policy.refresh_before();
632        let refresh_manager = self.refresh_manager.clone();
633        let chunk_cache = self.chunk_cache.clone();
634
635        // Check for range request early
636        let range_request = parse_range_header(req.headers());
637
638        // Prepare refresh metadata if auto-refresh is enabled
639        let refresh_metadata = if refresh_manager.is_some() && key.is_some() {
640            Some(RefreshMetadata::from_request(&req))
641        } else {
642            None
643        };
644
645        Box::pin(async move {
646            #[cfg(feature = "tracing")]
647            tracing::debug!(method = %method, uri = %uri, "cache_call");
648
649            // Try to serve from chunk cache if this is a range request
650            if let (Some(range_req), Some(ref chunk_cache), Some(ref key_ref)) =
651                (range_request.as_ref(), &chunk_cache, &key)
652            {
653                if let Some(entry) = chunk_cache.get(key_ref) {
654                    // Check if range is satisfiable and normalize it
655                    if let Some(normalized) = range_req.normalize(entry.metadata.total_size) {
656                        let end = normalized.end.unwrap_or(entry.metadata.total_size - 1);
657
658                        // Try to get range from chunks
659                        if let Some(range_data) = entry.get_range(normalized.start, end) {
660                            #[cfg(feature = "metrics")]
661                            counter!("tower_http_cache.chunk_cache_hit").increment(1);
662
663                            #[cfg(feature = "tracing")]
664                            tracing::debug!(
665                                key = %key_ref,
666                                start = normalized.start,
667                                end = end,
668                                "chunk_cache_hit"
669                            );
670
671                            // Build 206 Partial Content response
672                            let mut response = Response::builder()
673                                .status(http::StatusCode::PARTIAL_CONTENT)
674                                .body(SyncBoxBody::new(
675                                    Full::from(range_data).map_err(Into::into).boxed(),
676                                ))
677                                .unwrap();
678
679                            // Copy headers from metadata
680                            for (name, value) in &entry.metadata.headers {
681                                if let (Ok(header_name), Ok(header_value)) = (
682                                    http::header::HeaderName::from_bytes(name.as_bytes()),
683                                    http::header::HeaderValue::from_bytes(value),
684                                ) {
685                                    response.headers_mut().insert(header_name, header_value);
686                                }
687                            }
688
689                            // Add Content-Range header
690                            let content_range = format!(
691                                "bytes {}-{}/{}",
692                                normalized.start, end, entry.metadata.total_size
693                            );
694                            response.headers_mut().insert(
695                                http::header::CONTENT_RANGE,
696                                http::header::HeaderValue::from_str(&content_range).unwrap(),
697                            );
698
699                            // Add Content-Length
700                            let content_length = (end - normalized.start + 1).to_string();
701                            response.headers_mut().insert(
702                                http::header::CONTENT_LENGTH,
703                                http::header::HeaderValue::from_str(&content_length).unwrap(),
704                            );
705
706                            return Ok(response);
707                        }
708                    }
709                }
710
711                #[cfg(feature = "metrics")]
712                counter!("tower_http_cache.chunk_cache_miss").increment(1);
713            }
714
715            let mut stale_entry: Option<CacheEntry> = None;
716            if let Some(ref key_ref) = key {
717                if let Ok(Some(hit)) = backend.get(key_ref).await {
718                    match classify_hit(hit, stale_window, refresh_before) {
719                        HitState::Fresh(entry) => {
720                            #[cfg(feature = "metrics")]
721                            counter!("tower_http_cache.hit").increment(1);
722
723                            // Record hit for auto-refresh tracking
724                            if let Some(ref manager) = refresh_manager {
725                                manager.tracker().record_hit(key_ref);
726                            }
727
728                            return Ok(entry.into_response());
729                        }
730                        HitState::Stale(entry) => {
731                            #[cfg(feature = "metrics")]
732                            counter!("tower_http_cache.stale_hit").increment(1);
733
734                            // Record hit for auto-refresh tracking
735                            if let Some(ref manager) = refresh_manager {
736                                manager.tracker().record_hit(key_ref);
737                            }
738
739                            stale_entry = Some(entry);
740                        }
741                        HitState::Expired => {}
742                    }
743                }
744            }
745
746            let mut primary_guard: Option<StampedeGuard> = None;
747            if let Some(ref key_ref) = key {
748                match StampedeGuard::acquire_handle(locks.clone(), key_ref.clone()).await {
749                    StampedeHandle::Primary(guard) => {
750                        primary_guard = Some(guard);
751                    }
752                    StampedeHandle::Secondary(lock) => {
753                        if let Some(entry) = stale_entry.clone() {
754                            #[cfg(feature = "metrics")]
755                            counter!("tower_http_cache.stale_served").increment(1);
756                            return Ok(entry.into_response());
757                        }
758
759                        let secondary_guard = lock.lock_owned().await;
760                        drop(secondary_guard);
761
762                        if let Ok(Some(hit)) = backend.get(key_ref).await {
763                            match classify_hit(hit, stale_window, refresh_before) {
764                                HitState::Fresh(entry) => {
765                                    #[cfg(feature = "metrics")]
766                                    counter!("tower_http_cache.hit_after_wait").increment(1);
767                                    return Ok(entry.into_response());
768                                }
769                                HitState::Stale(entry) => {
770                                    #[cfg(feature = "metrics")]
771                                    counter!("tower_http_cache.stale_served").increment(1);
772                                    return Ok(entry.into_response());
773                                }
774                                HitState::Expired => {}
775                            }
776                        }
777
778                        if let StampedeHandle::Primary(guard) =
779                            StampedeGuard::acquire_handle(locks.clone(), key_ref.clone()).await
780                        {
781                            primary_guard = Some(guard);
782                        }
783                    }
784                }
785            }
786
787            #[cfg(feature = "metrics")]
788            counter!("tower_http_cache.miss").increment(1);
789
790            #[cfg(feature = "metrics")]
791            let start = std::time::Instant::now();
792            let service = inner;
793            let response = service.oneshot(req).await.map_err(|err| err.into())?;
794            #[cfg(feature = "metrics")]
795            histogram!("tower_http_cache.backend_latency").record(start.elapsed().as_secs_f64());
796
797            let (parts, body) = response.into_parts();
798
799            // NEW: Early streaming decision
800            let content_type = parts
801                .headers
802                .get(http::header::CONTENT_TYPE)
803                .and_then(|v| v.to_str().ok());
804
805            let content_length = parts
806                .headers
807                .get(http::header::CONTENT_LENGTH)
808                .and_then(|v| v.to_str().ok())
809                .and_then(|v| v.parse::<u64>().ok());
810
811            let size_hint = body.size_hint();
812
813            let streaming_decision = should_stream(
814                policy.streaming_policy(),
815                &size_hint,
816                content_type,
817                content_length,
818            );
819
820            // Check for range requests
821            let is_range_request = parse_range_header(&parts.headers).is_some();
822            let is_partial_response = is_partial_content(parts.status);
823            let range_handling = policy.streaming_policy().range_handling;
824
825            // Handle range requests according to policy
826            if (is_range_request || is_partial_response)
827                && range_handling == RangeHandling::PassThrough
828            {
829                #[cfg(feature = "metrics")]
830                counter!("tower_http_cache.range_request_passthrough").increment(1);
831
832                #[cfg(feature = "tracing")]
833                tracing::debug!(
834                    method = %method,
835                    uri = %uri,
836                    is_range = is_range_request,
837                    is_partial = is_partial_response,
838                    "range_request_passthrough"
839                );
840
841                // Stream through without buffering for range requests
842                let boxed_body = SyncBoxBody::new(body.map_err(Into::into).boxed());
843                drop(primary_guard);
844                return Ok(Response::from_parts(parts, boxed_body));
845            }
846
847            // Check if we should skip caching and pass through
848            match streaming_decision {
849                StreamingDecision::SkipCache | StreamingDecision::StreamThrough => {
850                    // TRUE STREAMING: Pass through without buffering!
851                    #[cfg(feature = "metrics")]
852                    counter!("tower_http_cache.streaming_passthrough").increment(1);
853
854                    #[cfg(feature = "tracing")]
855                    tracing::debug!(
856                        method = %method,
857                        uri = %uri,
858                        decision = ?streaming_decision,
859                        content_type = ?content_type,
860                        size = ?extract_size_info(&size_hint, content_length),
861                        "streaming_passthrough"
862                    );
863
864                    // Box the body and stream it through without collecting
865                    let boxed_body = SyncBoxBody::new(body.map_err(Into::into).boxed());
866                    drop(primary_guard);
867                    return Ok(Response::from_parts(parts, boxed_body));
868                }
869                _ => {}
870            }
871
872            let streaming = body.size_hint().upper().is_none();
873            if streaming && !policy.allow_streaming_bodies() {
874                #[cfg(feature = "metrics")]
875                counter!("tower_http_cache.streaming_skip").increment(1);
876            }
877
878            let collected = BodyExt::collect(body).await.map_err(|err| err.into())?;
879            let cache_bytes = collected.to_bytes();
880            let response_bytes = cache_bytes.clone();
881
882            // Populate chunk cache for large files if enabled
883            if let (Some(ref chunk_cache), Some(ref key_ref)) = (&chunk_cache, &key) {
884                let streaming_policy = policy.streaming_policy();
885
886                if streaming_policy.enable_chunk_cache
887                    && cache_bytes.len() as u64 >= streaming_policy.min_chunk_file_size
888                    && parts.status.is_success()
889                {
890                    #[cfg(feature = "tracing")]
891                    tracing::debug!(
892                        key = %key_ref,
893                        size = cache_bytes.len(),
894                        chunk_size = streaming_policy.chunk_size,
895                        "populating_chunk_cache"
896                    );
897
898                    // Create chunk metadata
899                    let metadata = ChunkMetadata {
900                        total_size: cache_bytes.len() as u64,
901                        content_type: parts
902                            .headers
903                            .get(http::header::CONTENT_TYPE)
904                            .and_then(|v| v.to_str().ok())
905                            .unwrap_or("application/octet-stream")
906                            .to_string(),
907                        etag: parts
908                            .headers
909                            .get(http::header::ETAG)
910                            .and_then(|v| v.to_str().ok())
911                            .map(|s| s.to_string()),
912                        last_modified: parts
913                            .headers
914                            .get(http::header::LAST_MODIFIED)
915                            .and_then(|v| v.to_str().ok())
916                            .map(|s| s.to_string()),
917                        status: parts.status,
918                        version: parts.version,
919                        headers: policy.headers_to_cache(&parts.headers),
920                    };
921
922                    // Get or create chunked entry
923                    let entry = chunk_cache.get_or_create(key_ref.clone(), metadata);
924
925                    // Split into chunks and store
926                    let chunk_size = streaming_policy.chunk_size;
927                    let mut offset = 0;
928                    let mut chunk_index = 0;
929
930                    while offset < cache_bytes.len() {
931                        let end = std::cmp::min(offset + chunk_size, cache_bytes.len());
932                        let chunk = cache_bytes.slice(offset..end);
933                        entry.add_chunk(chunk_index, chunk);
934
935                        offset = end;
936                        chunk_index += 1;
937                    }
938
939                    #[cfg(feature = "metrics")]
940                    counter!("tower_http_cache.chunk_cache_stored").increment(1);
941
942                    #[cfg(feature = "tracing")]
943                    tracing::debug!(
944                        key = %key_ref,
945                        chunks = chunk_index,
946                        "chunk_cache_populated"
947                    );
948                }
949            }
950
951            let cache_control_block =
952                policy.respect_cache_control() && cache_control_disallows(&parts.headers);
953            let body_too_large = policy
954                .max_body_size()
955                .is_some_and(|max| cache_bytes.len() > max);
956            let body_too_small = policy
957                .min_body_size()
958                .is_some_and(|min| cache_bytes.len() < min);
959
960            let should_store = key.is_some()
961                && !cache_control_block
962                && !body_too_large
963                && !body_too_small
964                && (policy.allow_streaming_bodies() || !streaming);
965
966            let headers_to_cache = if should_store {
967                Some(policy.headers_to_cache(&parts.headers))
968            } else {
969                None
970            };
971
972            let version = parts.version;
973            let status = parts.status;
974
975            if should_store {
976                if let Some(key_ref) = &key {
977                    if let Some(ttl) = policy.ttl_for(status) {
978                        if !ttl.is_zero() {
979                            let stale_for = policy.stale_while_revalidate();
980                            let (compressed_bytes, compressed) =
981                                maybe_compress(cache_bytes.clone(), policy.compression());
982                            if compressed {
983                                #[cfg(feature = "metrics")]
984                                counter!("tower_http_cache.compressed").increment(1);
985                            }
986                            let entry = CacheEntry::new(
987                                status,
988                                version,
989                                headers_to_cache.unwrap(),
990                                compressed_bytes,
991                            );
992                            if backend
993                                .set(key_ref.clone(), entry, ttl, stale_for)
994                                .await
995                                .is_err()
996                            {
997                                #[cfg(feature = "metrics")]
998                                counter!("tower_http_cache.store_error").increment(1);
999                            } else {
1000                                #[cfg(feature = "metrics")]
1001                                counter!("tower_http_cache.store").increment(1);
1002
1003                                // Store refresh metadata if auto-refresh is enabled
1004                                if let (Some(ref manager), Some(metadata)) =
1005                                    (&refresh_manager, refresh_metadata)
1006                                {
1007                                    manager.store_metadata(key_ref.clone(), metadata);
1008                                }
1009                            }
1010                        }
1011                    }
1012                }
1013            } else {
1014                #[cfg(feature = "metrics")]
1015                counter!("tower_http_cache.store_skipped").increment(1);
1016            }
1017
1018            drop(primary_guard);
1019
1020            // Box the response body
1021            let full_body = Full::from(response_bytes);
1022            let boxed_body = SyncBoxBody::new(full_body.map_err(Into::into).boxed());
1023            Ok(Response::from_parts(parts, boxed_body))
1024        })
1025    }
1026}
1027
/// Guard handed to the caller that created the per-key stampede lock.
///
/// While this value is alive the per-key mutex stays locked (through `_guard`).
/// Dropping it removes the key's entry from `locks`, provided the entry still
/// refers to the same mutex as `lock`.
struct StampedeGuard {
    /// Cache key this guard was acquired for.
    key: String,
    /// Shared registry of per-key locks that the entry is removed from on drop.
    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
    /// The mutex that was registered under `key`; used for identity comparison on drop.
    lock: Arc<Mutex<()>>,
    /// Owned lock guard on `lock`; never read, kept only so the mutex remains locked.
    _guard: OwnedMutexGuard<()>,
}
1034
/// Outcome of [`StampedeGuard::acquire_handle`] for a given key.
enum StampedeHandle {
    /// No lock was registered for the key: the caller registered one and holds it.
    Primary(StampedeGuard),
    /// A lock was already registered for the key; carries a handle to that mutex.
    // NOTE(review): presumably the caller awaits this mutex to wait for the
    // primary to finish — confirm against the call site.
    Secondary(Arc<Mutex<()>>),
}
1039
1040impl StampedeGuard {
1041    async fn acquire_handle(
1042        locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
1043        key: String,
1044    ) -> StampedeHandle {
1045        let handle = match locks.entry(key.clone()) {
1046            Entry::Occupied(entry) => StampedeHandle::Secondary(entry.get().clone()),
1047            Entry::Vacant(entry) => {
1048                let lock = Arc::new(Mutex::new(()));
1049                entry.insert(lock.clone());
1050                let guard = lock.clone().lock_owned().await;
1051                let locks_clone = locks.clone();
1052                StampedeHandle::Primary(StampedeGuard {
1053                    key,
1054                    locks: locks_clone,
1055                    lock,
1056                    _guard: guard,
1057                })
1058            }
1059        };
1060        handle
1061    }
1062}
1063
1064impl Drop for StampedeGuard {
1065    fn drop(&mut self) {
1066        if let Some(current) = self.locks.get(&self.key) {
1067            let should_remove = Arc::ptr_eq(&self.lock, current.value());
1068            drop(current);
1069            if should_remove {
1070                self.locks.remove(&self.key);
1071            }
1072        }
1073    }
1074}
1075
1076fn classify_hit(hit: CacheRead, stale_window: Duration, refresh_before: Duration) -> HitState {
1077    let now = SystemTime::now();
1078    let CacheRead {
1079        entry,
1080        expires_at,
1081        stale_until,
1082    } = hit;
1083
1084    if let Some(expires_at) = expires_at {
1085        if expires_at > now {
1086            if refresh_before > Duration::ZERO {
1087                if let Some(threshold) = expires_at.checked_sub(refresh_before) {
1088                    if now >= threshold {
1089                        return HitState::Stale(entry);
1090                    }
1091                } else {
1092                    return HitState::Stale(entry);
1093                }
1094            }
1095            return HitState::Fresh(entry);
1096        }
1097    }
1098
1099    if stale_window > Duration::ZERO {
1100        if let Some(stale_until) = stale_until {
1101            if stale_until > now {
1102                return HitState::Stale(entry);
1103            }
1104        }
1105    }
1106
1107    HitState::Expired
1108}
1109
/// Result of [`classify_hit`]: how a cached entry relates to its expiry deadlines.
#[derive(Debug)]
enum HitState {
    /// The entry has not expired and is not within the early-refresh window.
    Fresh(CacheEntry),
    /// The entry is within the early-refresh window, or has expired but is
    /// still inside its stale window.
    Stale(CacheEntry),
    /// The entry is past both its expiry and its stale window; no entry is carried.
    Expired,
}
1116
1117fn cache_control_disallows(headers: &HeaderMap) -> bool {
1118    headers
1119        .get_all(CACHE_CONTROL)
1120        .iter()
1121        .filter_map(|value| value.to_str().ok())
1122        .flat_map(|value| value.split(','))
1123        .map(|token| token.trim().to_ascii_lowercase())
1124        .any(|token| matches!(token.as_str(), "no-store" | "no-cache" | "private"))
1125        || headers
1126            .get(PRAGMA)
1127            .and_then(|value| value.to_str().ok())
1128            .map(|value| value.to_ascii_lowercase().contains("no-cache"))
1129            .unwrap_or(false)
1130}
1131
/// Optionally gzip-compresses `bytes` according to `config`.
///
/// Returns the resulting bytes together with a flag telling whether
/// compression was applied. The input is returned unchanged (flag `false`)
/// when the strategy is `None`, when the payload is shorter than
/// `config.min_size`, or when the encoder reports an error.
#[cfg(feature = "compression")]
fn maybe_compress(bytes: Bytes, config: CompressionConfig) -> (Bytes, bool) {
    use flate2::{write::GzEncoder, Compression};
    use std::io::Write;

    let gzip_requested = match config.strategy {
        CompressionStrategy::None => false,
        CompressionStrategy::Gzip => true,
    };
    if !gzip_requested || bytes.len() < config.min_size {
        return (bytes, false);
    }

    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    let gzipped = encoder
        .write_all(&bytes)
        .and_then(|()| encoder.finish());

    match gzipped {
        Ok(data) => (Bytes::from(data), true),
        // Any encoder failure falls back to the original, uncompressed payload.
        Err(_) => (bytes, false),
    }
}
1154
1155#[cfg(not(feature = "compression"))]
1156fn maybe_compress(bytes: Bytes, _config: CompressionConfig) -> (Bytes, bool) {
1157    let _ = _config;
1158    (bytes, false)
1159}
1160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::CacheEntry;
    use bytes::Bytes;
    use http::{HeaderValue, StatusCode, Version};
    use tokio::task::yield_now;

    /// Builds a minimal `200 OK` HTTP/1.1 entry with no headers and a tiny body.
    fn mock_entry() -> CacheEntry {
        let no_headers = Vec::new();
        let payload = Bytes::from_static(b"body");
        CacheEntry::new(StatusCode::OK, Version::HTTP_11, no_headers, payload)
    }

    #[test]
    fn classify_hit_marks_entry_fresh_when_not_near_expiry() {
        let now = SystemTime::now();
        let expires_at = Some(now + Duration::from_secs(10));
        let stale_until = Some(now + Duration::from_secs(20));
        let hit = CacheRead {
            entry: mock_entry(),
            expires_at,
            stale_until,
        };

        let state = classify_hit(hit, Duration::from_secs(5), Duration::from_secs(1));
        assert!(
            matches!(state, HitState::Fresh(_)),
            "expected fresh entry, got {:?}",
            state
        );
    }

    #[test]
    fn classify_hit_marks_entry_stale_when_within_refresh_window() {
        let now = SystemTime::now();
        let expires_at = Some(now + Duration::from_secs(2));
        let stale_until = Some(now + Duration::from_secs(10));
        let hit = CacheRead {
            entry: mock_entry(),
            expires_at,
            stale_until,
        };

        let state = classify_hit(hit, Duration::from_secs(5), Duration::from_secs(5));
        assert!(
            matches!(state, HitState::Stale(_)),
            "expected stale entry, got {:?}",
            state
        );
    }

    #[test]
    fn classify_hit_marks_entry_stale_when_within_stale_window() {
        let now = SystemTime::now();
        let expires_at = Some(now - Duration::from_secs(1));
        let stale_until = Some(now + Duration::from_secs(1));
        let hit = CacheRead {
            entry: mock_entry(),
            expires_at,
            stale_until,
        };

        let state = classify_hit(hit, Duration::from_secs(2), Duration::from_secs(0));
        assert!(
            matches!(state, HitState::Stale(_)),
            "expected stale entry, got {:?}",
            state
        );
    }

    #[test]
    fn classify_hit_marks_entry_expired_after_stale_window() {
        let now = SystemTime::now();
        let expires_at = Some(now - Duration::from_secs(5));
        let stale_until = Some(now - Duration::from_secs(1));
        let hit = CacheRead {
            entry: mock_entry(),
            expires_at,
            stale_until,
        };

        let state = classify_hit(hit, Duration::from_secs(5), Duration::from_secs(0));
        assert!(
            matches!(state, HitState::Expired),
            "expected expired entry, got {:?}",
            state
        );
    }

    #[test]
    fn cache_control_disallows_detects_no_cache_directives() {
        // `Cache-Control` carrying a blocking directive among others.
        let mut with_cache_control = HeaderMap::new();
        let directive = HeaderValue::from_static("max-age=0, no-cache");
        with_cache_control.insert(CACHE_CONTROL, directive);
        assert!(cache_control_disallows(&with_cache_control));

        // Legacy `Pragma` header on its own.
        let mut pragma_only = HeaderMap::new();
        pragma_only.insert(PRAGMA, HeaderValue::from_static("no-cache"));
        assert!(cache_control_disallows(&pragma_only));
    }

    #[tokio::test]
    async fn stampede_guard_drop_removes_lock_entry() {
        let locks = Arc::new(DashMap::new());
        let key = "key".to_string();

        let handle = StampedeGuard::acquire_handle(locks.clone(), key.clone()).await;
        let StampedeHandle::Primary(guard) = handle else {
            panic!("expected primary guard");
        };

        // The entry exists while the guard is alive and disappears once it is dropped.
        assert!(locks.get(&key).is_some());
        drop(guard);
        yield_now().await;
        assert!(locks.get(&key).is_none());
    }

    #[test]
    fn cache_service_implements_clone() {
        use crate::backend::memory::InMemoryBackend;
        use tower::service_fn;

        // Compile-time check that CacheService implements Clone
        fn assert_clone<T: Clone>(_: &T) {}

        let inner = service_fn(|_req: http::Request<()>| async {
            Ok::<_, std::convert::Infallible>(http::Response::new(()))
        });
        let service = CacheLayer::new(InMemoryBackend::new(100)).layer(inner);

        // This will fail to compile if CacheService doesn't implement Clone
        assert_clone(&service);
    }
}