tower_http_cache/backend/
mod.rs

//! Storage backends for the cache layer.
//!
//! The cache layer requires a [`CacheBackend`] implementation to persist
//! cached responses. This module ships with:
//! - [`memory::InMemoryBackend`] — a fast, process-local cache backed by [`moka`].
//! - `redis::RedisBackend` *(optional)* — a distributed cache, available when the
//!   `redis-backend` crate feature is enabled.
//! - `memcached::MemcachedBackend` *(optional)* — a distributed cache, available when
//!   the `memcached-backend` crate feature is enabled.
//!
//! Backends are responsible for answering cache lookups, storing entries,
//! and enforcing per-entry stale windows.
13
14#[cfg(feature = "memcached-backend")]
15pub mod memcached;
16pub mod memory;
17pub mod multi_tier;
18#[cfg(feature = "redis-backend")]
19pub mod redis;
20
21use async_trait::async_trait;
22use bytes::Bytes;
23use http::{HeaderName, HeaderValue, Response, StatusCode, Version};
24use std::time::{Duration, SystemTime};
25
26use crate::error::CacheError;
27use crate::layer::SyncBoxBody;
28
29/// Cached response payload captured by the cache layer.
30#[derive(Debug, Clone)]
31#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
32pub struct CacheEntry {
33    #[cfg_attr(feature = "serde", serde(with = "status_code_serde"))]
34    pub status: StatusCode,
35    #[cfg_attr(feature = "serde", serde(with = "version_serde"))]
36    pub version: Version,
37    pub headers: Vec<(String, Vec<u8>)>,
38    #[cfg_attr(feature = "serde", serde(with = "bytes_serde"))]
39    pub body: Bytes,
40    pub tags: Option<Vec<String>>,
41}
42
// Custom serde helpers for http types, wired up via `#[serde(with = "...")]`
// on `CacheEntry`.
#[cfg(feature = "serde")]
mod status_code_serde {
    //! Encodes [`http::StatusCode`] as its numeric `u16` value.

    use http::StatusCode;
    use serde::{Deserialize, Deserializer, Serializer};

    /// Writes the status as a plain `u16` (e.g. `404`).
    pub fn serialize<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_u16(status.as_u16())
    }

    /// Reads a `u16` and validates it as an HTTP status code.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = u16::deserialize(deserializer)?;
        match StatusCode::from_u16(raw) {
            Ok(status) => Ok(status),
            Err(err) => Err(serde::de::Error::custom(err)),
        }
    }
}
64
#[cfg(feature = "serde")]
mod version_serde {
    //! Encodes [`http::Version`] as a compact `u8` code.

    use http::Version;
    use serde::{Deserialize, Deserializer, Serializer};

    /// Code written for versions not listed explicitly; read back as HTTP/1.1.
    const UNKNOWN: u8 = 5;

    /// Writes `version` as a single `u8` code.
    pub fn serialize<S>(version: &Version, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The explicit `u8` is essential. Without it the integer literals fall back
        // to `i32`, and fixed-width binary formats would write 4 bytes that
        // `deserialize` (which reads a `u8`) cannot read back correctly.
        let code: u8 = match *version {
            Version::HTTP_09 => 0,
            Version::HTTP_10 => 1,
            Version::HTTP_11 => 2,
            Version::HTTP_2 => 3,
            Version::HTTP_3 => 4,
            _ => UNKNOWN,
        };
        serializer.serialize_u8(code)
    }

    /// Reads a `u8` code written by [`serialize`].
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Version, D::Error>
    where
        D: Deserializer<'de>,
    {
        Ok(match u8::deserialize(deserializer)? {
            0 => Version::HTTP_09,
            1 => Version::HTTP_10,
            2 => Version::HTTP_11,
            3 => Version::HTTP_2,
            4 => Version::HTTP_3,
            // Unknown codes (including `UNKNOWN`) fall back to HTTP/1.1 instead of
            // failing, so an unrecognised version never makes an entry unreadable.
            _ => Version::HTTP_11,
        })
    }
}
100
#[cfg(feature = "serde")]
mod bytes_serde {
    //! Encodes [`bytes::Bytes`] using serde's native byte-array type.

    use bytes::Bytes;
    use serde::de::{self, SeqAccess, Visitor};
    use serde::{Deserializer, Serializer};
    use std::fmt;

    /// Upper bound on preallocation driven by an untrusted sequence size hint.
    const MAX_PREALLOC: usize = 4096;

    /// Writes the body as a serde byte array.
    pub fn serialize<S>(bytes: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_bytes(bytes)
    }

    /// Reads a byte array back into [`Bytes`].
    ///
    /// `serialize_bytes` is represented natively by some formats (e.g. MessagePack,
    /// CBOR) and as a sequence of `u8` by others (e.g. JSON). `Vec<u8>`'s own
    /// `Deserialize` only accepts sequences, so a dedicated visitor is used that
    /// accepts both representations.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_byte_buf(BytesVisitor)
    }

    /// Visitor accepting borrowed bytes, owned byte buffers, or a `u8` sequence.
    struct BytesVisitor;

    impl<'de> Visitor<'de> for BytesVisitor {
        type Value = Bytes;

        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("a byte array")
        }

        fn visit_bytes<E: de::Error>(self, v: &[u8]) -> Result<Bytes, E> {
            Ok(Bytes::copy_from_slice(v))
        }

        fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Bytes, E> {
            // Takes ownership of the buffer without copying.
            Ok(Bytes::from(v))
        }

        fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Bytes, A::Error> {
            // Cap the hint so a hostile length prefix cannot force a huge allocation.
            let mut buf = Vec::with_capacity(seq.size_hint().unwrap_or(0).min(MAX_PREALLOC));
            while let Some(byte) = seq.next_element::<u8>()? {
                buf.push(byte);
            }
            Ok(Bytes::from(buf))
        }
    }
}
121
122impl CacheEntry {
123    /// Creates a new cached response entry.
124    ///
125    /// The entry captures the response status, HTTP version, a serialized
126    /// subset of headers, and the collected response body.
127    pub fn new(
128        status: StatusCode,
129        version: Version,
130        headers: Vec<(String, Vec<u8>)>,
131        body: Bytes,
132    ) -> Self {
133        Self {
134            status,
135            version,
136            headers,
137            body,
138            tags: None,
139        }
140    }
141
142    /// Creates a new cached response entry with tags.
143    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
144        self.tags = Some(tags);
145        self
146    }
147
148    /// Converts the entry back into an `http::Response`.
149    pub fn into_response(self) -> Response<SyncBoxBody> {
150        use http_body_util::BodyExt;
151
152        let full_body = http_body_util::Full::from(self.body);
153        let boxed_body = full_body
154            .map_err(|never| -> Box<dyn std::error::Error + Send + Sync> { match never {} })
155            .boxed();
156
157        let mut response = Response::new(SyncBoxBody::new(boxed_body));
158        *response.status_mut() = self.status;
159        *response.version_mut() = self.version;
160
161        let headers = response.headers_mut();
162        headers.clear();
163        for (name, value) in self.headers {
164            if let (Ok(name), Ok(value)) = (
165                HeaderName::from_bytes(name.as_bytes()),
166                HeaderValue::from_bytes(&value),
167            ) {
168                headers.append(name, value);
169            }
170        }
171
172        response
173    }
174}
175
176#[derive(Debug, Clone)]
177pub struct CacheRead {
178    /// Cached entry together with timing metadata.
179    pub entry: CacheEntry,
180    pub expires_at: Option<SystemTime>,
181    pub stale_until: Option<SystemTime>,
182}
183
184#[async_trait]
185pub trait CacheBackend: Send + Sync + Clone + 'static {
186    /// Fetches a cached entry by key.
187    ///
188    /// Returns `Ok(None)` when the backend does not have a value or the
189    /// entry has expired.
190    async fn get(&self, key: &str) -> Result<Option<CacheRead>, CacheError>;
191
192    /// Stores an entry with a time-to-live and additional stale window.
193    async fn set(
194        &self,
195        key: String,
196        entry: CacheEntry,
197        ttl: Duration,
198        stale_for: Duration,
199    ) -> Result<(), CacheError>;
200
201    /// Invalidates the cache entry for `key`, if present.
202    async fn invalidate(&self, key: &str) -> Result<(), CacheError>;
203
204    /// Retrieves all cache keys associated with a tag.
205    ///
206    /// Returns an empty vector if tags are not supported by this backend.
207    async fn get_keys_by_tag(&self, _tag: &str) -> Result<Vec<String>, CacheError> {
208        Ok(Vec::new())
209    }
210
211    /// Invalidates all cache entries associated with a tag.
212    ///
213    /// Returns the number of entries invalidated.
214    async fn invalidate_by_tag(&self, tag: &str) -> Result<usize, CacheError> {
215        let keys = self.get_keys_by_tag(tag).await?;
216        let count = keys.len();
217        for key in keys {
218            let _ = self.invalidate(&key).await;
219        }
220        Ok(count)
221    }
222
223    /// Invalidates all cache entries associated with multiple tags.
224    ///
225    /// Returns the total number of entries invalidated (may include duplicates).
226    async fn invalidate_by_tags(&self, tags: &[String]) -> Result<usize, CacheError> {
227        let mut total = 0;
228        for tag in tags {
229            total += self.invalidate_by_tag(tag).await?;
230        }
231        Ok(total)
232    }
233
234    /// Lists all currently indexed tags.
235    ///
236    /// Returns an empty vector if tags are not supported by this backend.
237    async fn list_tags(&self) -> Result<Vec<String>, CacheError> {
238        Ok(Vec::new())
239    }
240}