Skip to main content

spider_lib/middlewares/
http_cache.rs

//! HTTP Cache Middleware for caching web responses.
//!
//! This module provides the `HttpCacheMiddleware`, which intercepts HTTP requests and
//! responses to implement a caching mechanism. It stores successful (2xx) HTTP responses
//! in a local directory and, for subsequent identical requests, serves the cached response
//! instead of making a new network request. This can significantly reduce network traffic,
//! improve crawling speed, and enable offline processing or replay of crawls.
//!
//! The cache uses request fingerprints to identify unique requests and associates them
//! with their corresponding cached responses. Responses are serialized and deserialized
//! using `bincode`.
12
13use async_trait::async_trait;
14use reqwest::StatusCode;
15use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
16use std::path::PathBuf;
17use tokio::fs;
18use tracing::{debug, info, warn};
19
20use crate::error::SpiderError;
21use crate::middleware::{Middleware, MiddlewareAction};
22use crate::request::Request;
23use crate::response::Response;
24use crate::utils;
25use bytes::Bytes;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use url::Url;
28
29fn serialize_headermap<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
30where
31    S: Serializer,
32{
33    let mut map = std::collections::HashMap::<String, String>::new();
34    for (name, value) in headers.iter() {
35        map.insert(
36            name.to_string(),
37            value.to_str().unwrap_or_default().to_string(),
38        );
39    }
40    map.serialize(serializer)
41}
42
43fn deserialize_headermap<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
44where
45    D: Deserializer<'de>,
46{
47    let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
48    let mut headers = HeaderMap::new();
49    for (name, value) in map {
50        if let (Ok(header_name), Ok(header_value)) =
51            (name.parse::<HeaderName>(), value.parse::<HeaderValue>())
52        {
53            headers.insert(header_name, header_value);
54        } else {
55            warn!("Failed to parse header: {} = {}", name, value);
56        }
57    }
58    Ok(headers)
59}
60
61fn serialize_statuscode<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
62where
63    S: Serializer,
64{
65    status.as_u16().serialize(serializer)
66}
67
68fn deserialize_statuscode<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
69where
70    D: Deserializer<'de>,
71{
72    let status_u16 = u16::deserialize(deserializer)?;
73    StatusCode::from_u16(status_u16).map_err(serde::de::Error::custom)
74}
75
76fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
77where
78    S: Serializer,
79{
80    url.to_string().serialize(serializer)
81}
82
83fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
84where
85    D: Deserializer<'de>,
86{
87    let s = String::deserialize(deserializer)?;
88    Url::parse(&s).map_err(serde::de::Error::custom)
89}
90
91/// Represents a cached response, including enough information to reconstruct a `Response` object.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93struct CachedResponse {
94    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
95    url: Url,
96    #[serde(
97        serialize_with = "serialize_statuscode",
98        deserialize_with = "deserialize_statuscode"
99    )]
100    status: StatusCode,
101    #[serde(
102        serialize_with = "serialize_headermap",
103        deserialize_with = "deserialize_headermap"
104    )]
105    headers: HeaderMap,
106    body: Vec<u8>,
107    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
108    request_url: Url,
109}
110
111impl From<Response> for CachedResponse {
112    fn from(response: Response) -> Self {
113        CachedResponse {
114            url: response.url,
115            status: response.status,
116            headers: response.headers,
117            body: response.body.to_vec(),
118            request_url: response.request_url,
119        }
120    }
121}
122
123impl From<CachedResponse> for Response {
124    fn from(cached_response: CachedResponse) -> Self {
125        Response {
126            url: cached_response.url,
127            status: cached_response.status,
128            headers: cached_response.headers,
129            body: Bytes::from(cached_response.body),
130            request_url: cached_response.request_url,
131            meta: Default::default(),
132        }
133    }
134}
135
/// Builder for [`HttpCacheMiddleware`].
///
/// Create one with `HttpCacheMiddleware::builder()`, optionally set a cache directory,
/// then call `build`.
#[derive(Debug, Default)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct HttpCacheMiddlewareBuilder {
    /// Directory for cache files.
    ///
    /// `None` means `build` falls back to `<platform cache dir>/spider-lib/http_cache`.
    cache_dir: Option<PathBuf>,
}
141
142impl HttpCacheMiddlewareBuilder {
143    /// Sets the directory where cache files will be stored.
144    pub fn cache_dir(mut self, path: PathBuf) -> Self {
145        self.cache_dir = Some(path);
146        self
147    }
148
149    /// Builds the `HttpCacheMiddleware`.
150    /// This can fail if the cache directory cannot be created or determined.
151    pub fn build(self) -> Result<HttpCacheMiddleware, SpiderError> {
152        let cache_dir = if let Some(path) = self.cache_dir {
153            path
154        } else {
155            dirs::cache_dir()
156                .ok_or_else(|| {
157                    SpiderError::ConfigurationError(
158                        "Could not determine cache directory".to_string(),
159                    )
160                })?
161                .join("spider-lib")
162                .join("http_cache")
163        };
164
165        utils::create_dir(&cache_dir)?;
166
167        let middleware = HttpCacheMiddleware { cache_dir };
168        info!(
169            "Initializing HttpCacheMiddleware with config: {:?}",
170            middleware
171        );
172
173        Ok(middleware)
174    }
175}
176
/// Middleware that answers requests from an on-disk response cache and stores
/// successful responses into it.
///
/// Construct it with `HttpCacheMiddleware::builder()`.
#[derive(Debug)]
pub struct HttpCacheMiddleware {
    /// Directory holding one `<fingerprint>.bin` file per cached response.
    cache_dir: PathBuf,
}
181
182impl HttpCacheMiddleware {
183    /// Creates a new `HttpCacheMiddlewareBuilder` to start building an `HttpCacheMiddleware`.
184    pub fn builder() -> HttpCacheMiddlewareBuilder {
185        HttpCacheMiddlewareBuilder::default()
186    }
187
188    fn get_cache_file_path(&self, fingerprint: &str) -> PathBuf {
189        self.cache_dir.join(format!("{}.bin", fingerprint))
190    }
191}
192
193#[async_trait]
194impl<C: Send + Sync> Middleware<C> for HttpCacheMiddleware {
195    fn name(&self) -> &str {
196        "HttpCacheMiddleware"
197    }
198
199    async fn process_request(
200        &mut self,
201        _client: &C,
202        request: Request,
203    ) -> Result<MiddlewareAction<Request>, SpiderError> {
204        let fingerprint = request.fingerprint();
205        let cache_file_path = self.get_cache_file_path(&fingerprint);
206
207        if fs::metadata(&cache_file_path).await.is_ok() {
208            debug!("Cache hit for request: {}", request.url);
209            match fs::read(&cache_file_path).await {
210                Ok(cached_bytes) => match bincode::deserialize::<CachedResponse>(&cached_bytes) {
211                    Ok(cached_resp) => {
212                        let mut response: Response = cached_resp.into();
213                        response.meta = request.meta;
214                        debug!("Returning cached response for {}", response.url);
215                        return Ok(MiddlewareAction::ReturnResponse(response));
216                    }
217                    Err(e) => {
218                        warn!(
219                            "Failed to deserialize cached response from {}: {}. Deleting invalid cache file.",
220                            cache_file_path.display(),
221                            e
222                        );
223                        fs::remove_file(&cache_file_path).await.ok();
224                    }
225                },
226                Err(e) => {
227                    warn!(
228                        "Failed to read cache file {}: {}. Deleting invalid cache file.",
229                        cache_file_path.display(),
230                        e
231                    );
232                    fs::remove_file(&cache_file_path).await.ok();
233                }
234            }
235        }
236
237        debug!("Cache miss for request: {}", request.url);
238        Ok(MiddlewareAction::Continue(request))
239    }
240
241    async fn process_response(
242        &mut self,
243        response: Response,
244    ) -> Result<MiddlewareAction<Response>, SpiderError> {
245        // Only cache successful responses (e.g., 200 OK)
246        if response.status.is_success() {
247            let original_request_fingerprint = response.request_from_response().fingerprint();
248            let cache_file_path = self.get_cache_file_path(&original_request_fingerprint);
249
250            let cached_response: CachedResponse = response.clone().into();
251            match bincode::serialize(&cached_response) {
252                Ok(serialized_bytes) => {
253                    fs::write(&cache_file_path, serialized_bytes)
254                        .await
255                        .map_err(|e| SpiderError::IoError(e.to_string()))?;
256                    debug!("Cached response for {}", response.url);
257                }
258                Err(e) => {
259                    warn!(
260                        "Failed to serialize response for caching {}: {}",
261                        response.url, e
262                    );
263                }
264            }
265        }
266
267        Ok(MiddlewareAction::Continue(response))
268    }
269}