1use async_trait::async_trait;
14use reqwest::StatusCode;
15use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
16use std::path::PathBuf;
17use tokio::fs;
18use tracing::{debug, info, warn};
19
20use crate::error::SpiderError;
21use crate::middleware::{Middleware, MiddlewareAction};
22use crate::request::Request;
23use crate::response::Response;
24use crate::utils;
25use bytes::Bytes;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use url::Url;
28
29fn serialize_headermap<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
30where
31 S: Serializer,
32{
33 let mut map = std::collections::HashMap::<String, String>::new();
34 for (name, value) in headers.iter() {
35 map.insert(
36 name.to_string(),
37 value.to_str().unwrap_or_default().to_string(),
38 );
39 }
40 map.serialize(serializer)
41}
42
43fn deserialize_headermap<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
44where
45 D: Deserializer<'de>,
46{
47 let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
48 let mut headers = HeaderMap::new();
49 for (name, value) in map {
50 if let (Ok(header_name), Ok(header_value)) =
51 (name.parse::<HeaderName>(), value.parse::<HeaderValue>())
52 {
53 headers.insert(header_name, header_value);
54 } else {
55 warn!("Failed to parse header: {} = {}", name, value);
56 }
57 }
58 Ok(headers)
59}
60
61fn serialize_statuscode<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
62where
63 S: Serializer,
64{
65 status.as_u16().serialize(serializer)
66}
67
68fn deserialize_statuscode<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
69where
70 D: Deserializer<'de>,
71{
72 let status_u16 = u16::deserialize(deserializer)?;
73 StatusCode::from_u16(status_u16).map_err(serde::de::Error::custom)
74}
75
76fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
77where
78 S: Serializer,
79{
80 url.to_string().serialize(serializer)
81}
82
83fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
84where
85 D: Deserializer<'de>,
86{
87 let s = String::deserialize(deserializer)?;
88 Url::parse(&s).map_err(serde::de::Error::custom)
89}
90
/// Serializable snapshot of a [`Response`] as stored in the on-disk cache.
///
/// Written with `bincode::serialize` in `process_response` and read back with
/// `bincode::deserialize` in `process_request`. The response `meta` is not stored;
/// on a cache hit the current request's `meta` is attached instead.
///
/// NOTE(review): bincode encodes structs positionally, so the field order below is
/// presumably part of the cache file format — reordering or inserting fields would make
/// existing cache files fail to decode (they are then deleted and refetched). Confirm
/// before changing.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedResponse {
    /// Response URL, stored as its string form.
    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
    url: Url,
    /// HTTP status, stored as a `u16`.
    #[serde(
        serialize_with = "serialize_statuscode",
        deserialize_with = "deserialize_statuscode"
    )]
    status: StatusCode,
    /// Response headers, encoded by `serialize_headermap` / `deserialize_headermap`.
    #[serde(
        serialize_with = "serialize_headermap",
        deserialize_with = "deserialize_headermap"
    )]
    headers: HeaderMap,
    /// Response body as an owned byte buffer.
    body: Vec<u8>,
    /// Copied from `Response::request_url` (presumably the originating request's URL).
    #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
    request_url: Url,
}
110
111impl From<Response> for CachedResponse {
112 fn from(response: Response) -> Self {
113 CachedResponse {
114 url: response.url,
115 status: response.status,
116 headers: response.headers,
117 body: response.body.to_vec(),
118 request_url: response.request_url,
119 }
120 }
121}
122
123impl From<CachedResponse> for Response {
124 fn from(cached_response: CachedResponse) -> Self {
125 Response {
126 url: cached_response.url,
127 status: cached_response.status,
128 headers: cached_response.headers,
129 body: Bytes::from(cached_response.body),
130 request_url: cached_response.request_url,
131 meta: Default::default(),
132 }
133 }
134}
135
/// Builder for [`HttpCacheMiddleware`]; obtain one via `HttpCacheMiddleware::builder()`.
///
/// Derives `Debug` so the builder can be logged and inspected like other public types.
#[derive(Debug, Default)]
pub struct HttpCacheMiddlewareBuilder {
    /// Explicit cache directory. When `None`, `build` falls back to a directory under
    /// the platform cache location.
    cache_dir: Option<PathBuf>,
}
141
142impl HttpCacheMiddlewareBuilder {
143 pub fn cache_dir(mut self, path: PathBuf) -> Self {
145 self.cache_dir = Some(path);
146 self
147 }
148
149 pub fn build(self) -> Result<HttpCacheMiddleware, SpiderError> {
152 let cache_dir = if let Some(path) = self.cache_dir {
153 path
154 } else {
155 dirs::cache_dir()
156 .ok_or_else(|| {
157 SpiderError::ConfigurationError(
158 "Could not determine cache directory".to_string(),
159 )
160 })?
161 .join("spider-lib")
162 .join("http_cache")
163 };
164
165 utils::create_dir(&cache_dir)?;
166
167 let middleware = HttpCacheMiddleware { cache_dir };
168 info!(
169 "Initializing HttpCacheMiddleware with config: {:?}",
170 middleware
171 );
172
173 Ok(middleware)
174 }
175}
176
/// Middleware that stores successful (`2xx`) responses on disk and serves later
/// requests with the same fingerprint from that store.
///
/// Construct it with `HttpCacheMiddleware::builder()`.
#[derive(Debug)]
pub struct HttpCacheMiddleware {
    /// Directory containing one `<fingerprint>.bin` file per cached response.
    cache_dir: PathBuf,
}
181
182impl HttpCacheMiddleware {
183 pub fn builder() -> HttpCacheMiddlewareBuilder {
185 HttpCacheMiddlewareBuilder::default()
186 }
187
188 fn get_cache_file_path(&self, fingerprint: &str) -> PathBuf {
189 self.cache_dir.join(format!("{}.bin", fingerprint))
190 }
191}
192
193#[async_trait]
194impl<C: Send + Sync> Middleware<C> for HttpCacheMiddleware {
195 fn name(&self) -> &str {
196 "HttpCacheMiddleware"
197 }
198
199 async fn process_request(
200 &mut self,
201 _client: &C,
202 request: Request,
203 ) -> Result<MiddlewareAction<Request>, SpiderError> {
204 let fingerprint = request.fingerprint();
205 let cache_file_path = self.get_cache_file_path(&fingerprint);
206
207 if fs::metadata(&cache_file_path).await.is_ok() {
208 debug!("Cache hit for request: {}", request.url);
209 match fs::read(&cache_file_path).await {
210 Ok(cached_bytes) => match bincode::deserialize::<CachedResponse>(&cached_bytes) {
211 Ok(cached_resp) => {
212 let mut response: Response = cached_resp.into();
213 response.meta = request.meta;
214 debug!("Returning cached response for {}", response.url);
215 return Ok(MiddlewareAction::ReturnResponse(response));
216 }
217 Err(e) => {
218 warn!(
219 "Failed to deserialize cached response from {}: {}. Deleting invalid cache file.",
220 cache_file_path.display(),
221 e
222 );
223 fs::remove_file(&cache_file_path).await.ok();
224 }
225 },
226 Err(e) => {
227 warn!(
228 "Failed to read cache file {}: {}. Deleting invalid cache file.",
229 cache_file_path.display(),
230 e
231 );
232 fs::remove_file(&cache_file_path).await.ok();
233 }
234 }
235 }
236
237 debug!("Cache miss for request: {}", request.url);
238 Ok(MiddlewareAction::Continue(request))
239 }
240
241 async fn process_response(
242 &mut self,
243 response: Response,
244 ) -> Result<MiddlewareAction<Response>, SpiderError> {
245 if response.status.is_success() {
247 let original_request_fingerprint = response.request_from_response().fingerprint();
248 let cache_file_path = self.get_cache_file_path(&original_request_fingerprint);
249
250 let cached_response: CachedResponse = response.clone().into();
251 match bincode::serialize(&cached_response) {
252 Ok(serialized_bytes) => {
253 fs::write(&cache_file_path, serialized_bytes)
254 .await
255 .map_err(|e| SpiderError::IoError(e.to_string()))?;
256 debug!("Cached response for {}", response.url);
257 }
258 Err(e) => {
259 warn!(
260 "Failed to serialize response for caching {}: {}",
261 response.url, e
262 );
263 }
264 }
265 }
266
267 Ok(MiddlewareAction::Continue(response))
268 }
269}