Skip to main content

spider_lib/middlewares/
http_cache.rs

1use async_trait::async_trait;
2use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
3use reqwest::StatusCode;
4use std::path::PathBuf;
5use tokio::fs;
6use tracing::{debug, info, warn};
7
8use crate::error::SpiderError;
9use crate::middleware::{Middleware, MiddlewareAction};
10use crate::request::Request;
11use crate::response::Response;
12use crate::utils;
13use bytes::Bytes;
14use serde::{Deserialize, Deserializer, Serialize, Serializer};
15use url::Url;
16
17fn serialize_headermap<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
18where
19    S: Serializer,
20{
21    let mut map = std::collections::HashMap::<String, String>::new();
22    for (name, value) in headers.iter() {
23        map.insert(
24            name.to_string(),
25            value.to_str().unwrap_or_default().to_string(),
26        );
27    }
28    map.serialize(serializer)
29}
30
31fn deserialize_headermap<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
32where
33    D: Deserializer<'de>,
34{
35    let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
36    let mut headers = HeaderMap::new();
37    for (name, value) in map {
38        if let (Ok(header_name), Ok(header_value)) =
39            (name.parse::<HeaderName>(), value.parse::<HeaderValue>())
40        {
41            headers.insert(header_name, header_value);
42        } else {
43            warn!("Failed to parse header: {} = {}", name, value);
44        }
45    }
46    Ok(headers)
47}
48
49fn serialize_statuscode<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
50where
51    S: Serializer,
52{
53    status.as_u16().serialize(serializer)
54}
55
56fn deserialize_statuscode<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
57where
58    D: Deserializer<'de>,
59{
60    let status_u16 = u16::deserialize(deserializer)?;
61    StatusCode::from_u16(status_u16).map_err(serde::de::Error::custom)
62}
63
64fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
65where
66    S: Serializer,
67{
68    url.to_string().serialize(serializer)
69}
70
71fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
72where
73    D: Deserializer<'de>,
74{
75    let s = String::deserialize(deserializer)?;
76    Url::parse(&s).map_err(serde::de::Error::custom)
77}
78
79/// Represents a cached response, including enough information to reconstruct a `Response` object.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81struct CachedResponse {
82    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
83    url: Url,
84    #[serde(
85        serialize_with = "serialize_statuscode",
86        deserialize_with = "deserialize_statuscode"
87    )]
88    status: StatusCode,
89    #[serde(
90        serialize_with = "serialize_headermap",
91        deserialize_with = "deserialize_headermap"
92    )]
93    headers: HeaderMap,
94    body: Vec<u8>,
95    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
96    request_url: Url,
97}
98
99impl From<Response> for CachedResponse {
100    fn from(response: Response) -> Self {
101        CachedResponse {
102            url: response.url,
103            status: response.status,
104            headers: response.headers,
105            body: response.body.to_vec(),
106            request_url: response.request_url,
107        }
108    }
109}
110
111impl From<CachedResponse> for Response {
112    fn from(cached_response: CachedResponse) -> Self {
113        Response {
114            url: cached_response.url,
115            status: cached_response.status,
116            headers: cached_response.headers,
117            body: Bytes::from(cached_response.body),
118            request_url: cached_response.request_url,
119            meta: Default::default(),
120        }
121    }
122}
123
/// Builder for `HttpCacheMiddleware`.
///
/// Obtain one from `HttpCacheMiddleware::builder()`, optionally set a cache directory,
/// then call `build`.
#[derive(Default)]
pub struct HttpCacheMiddlewareBuilder {
    /// Explicit cache directory. When `None`, `build` falls back to
    /// `dirs::cache_dir()/spider-lib/http_cache`.
    cache_dir: Option<PathBuf>,
}
129
130impl HttpCacheMiddlewareBuilder {
131    /// Sets the directory where cache files will be stored.
132    pub fn cache_dir(mut self, path: PathBuf) -> Self {
133        self.cache_dir = Some(path);
134        self
135    }
136
137    /// Builds the `HttpCacheMiddleware`.
138    /// This can fail if the cache directory cannot be created or determined.
139    pub fn build(self) -> Result<HttpCacheMiddleware, SpiderError> {
140        let cache_dir = if let Some(path) = self.cache_dir {
141            path
142        } else {
143            dirs::cache_dir()
144                .ok_or_else(|| {
145                    SpiderError::ConfigurationError(
146                        "Could not determine cache directory".to_string(),
147                    )
148                })?
149                .join("spider-lib")
150                .join("http_cache")
151        };
152
153        utils::create_dir(&cache_dir)?;
154
155        let middleware = HttpCacheMiddleware { cache_dir };
156        info!(
157            "Initializing HttpCacheMiddleware with config: {:?}",
158            middleware
159        );
160
161        Ok(middleware)
162    }
163}
164
/// Middleware that stores successful responses on disk and replays them for later
/// requests with the same fingerprint.
#[derive(Debug)]
pub struct HttpCacheMiddleware {
    /// Directory holding one `<fingerprint>.bin` file per cached response.
    cache_dir: PathBuf,
}
169
170impl HttpCacheMiddleware {
171    /// Creates a new `HttpCacheMiddlewareBuilder` to start building an `HttpCacheMiddleware`.
172    pub fn builder() -> HttpCacheMiddlewareBuilder {
173        HttpCacheMiddlewareBuilder::default()
174    }
175
176    fn get_cache_file_path(&self, fingerprint: &str) -> PathBuf {
177        self.cache_dir.join(format!("{}.bin", fingerprint))
178    }
179}
180
181#[async_trait]
182impl<C: Send + Sync> Middleware<C> for HttpCacheMiddleware {
183    fn name(&self) -> &str {
184        "HttpCacheMiddleware"
185    }
186
187    async fn process_request(
188        &mut self,
189        _client: &C,
190        request: Request,
191    ) -> Result<MiddlewareAction<Request>, SpiderError> {
192        let fingerprint = request.fingerprint();
193        let cache_file_path = self.get_cache_file_path(&fingerprint);
194
195        if fs::metadata(&cache_file_path).await.is_ok() {
196            debug!("Cache hit for request: {}", request.url);
197            match fs::read(&cache_file_path).await {
198                Ok(cached_bytes) => match bincode::deserialize::<CachedResponse>(&cached_bytes) {
199                    Ok(cached_resp) => {
200                        let mut response: Response = cached_resp.into();
201                        response.meta = request.meta;
202                        debug!("Returning cached response for {}", response.url);
203                        return Ok(MiddlewareAction::ReturnResponse(response));
204                    }
205                    Err(e) => {
206                        warn!(
207                            "Failed to deserialize cached response from {}: {}. Deleting invalid cache file.",
208                            cache_file_path.display(),
209                            e
210                        );
211                        fs::remove_file(&cache_file_path).await.ok();
212                    }
213                },
214                Err(e) => {
215                    warn!(
216                        "Failed to read cache file {}: {}. Deleting invalid cache file.",
217                        cache_file_path.display(),
218                        e
219                    );
220                    fs::remove_file(&cache_file_path).await.ok();
221                }
222            }
223        }
224
225        debug!("Cache miss for request: {}", request.url);
226        Ok(MiddlewareAction::Continue(request))
227    }
228
229    async fn process_response(
230        &mut self,
231        response: Response,
232    ) -> Result<MiddlewareAction<Response>, SpiderError> {
233        // Only cache successful responses (e.g., 200 OK)
234        if response.status.is_success() {
235            let original_request_fingerprint = response.request_from_response().fingerprint();
236            let cache_file_path = self.get_cache_file_path(&original_request_fingerprint);
237
238            let cached_response: CachedResponse = response.clone().into();
239            match bincode::serialize(&cached_response) {
240                Ok(serialized_bytes) => {
241                    fs::write(&cache_file_path, serialized_bytes)
242                        .await
243                        .map_err(SpiderError::IoError)?;
244                    debug!("Cached response for {}", response.url);
245                }
246                Err(e) => {
247                    warn!(
248                        "Failed to serialize response for caching {}: {}",
249                        response.url, e
250                    );
251                }
252            }
253        }
254
255        Ok(MiddlewareAction::Continue(response))
256    }
257}