1use async_trait::async_trait;
2use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
3use reqwest::StatusCode;
4use std::path::PathBuf;
5use tokio::fs;
6use tracing::{debug, info, warn};
7
8use crate::error::SpiderError;
9use crate::middleware::{Middleware, MiddlewareAction};
10use crate::request::Request;
11use crate::response::Response;
12use crate::utils;
13use bytes::Bytes;
14use serde::{Deserialize, Deserializer, Serialize, Serializer};
15use url::Url;
16
17fn serialize_headermap<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
18where
19 S: Serializer,
20{
21 let mut map = std::collections::HashMap::<String, String>::new();
22 for (name, value) in headers.iter() {
23 map.insert(
24 name.to_string(),
25 value.to_str().unwrap_or_default().to_string(),
26 );
27 }
28 map.serialize(serializer)
29}
30
31fn deserialize_headermap<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
32where
33 D: Deserializer<'de>,
34{
35 let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
36 let mut headers = HeaderMap::new();
37 for (name, value) in map {
38 if let (Ok(header_name), Ok(header_value)) =
39 (name.parse::<HeaderName>(), value.parse::<HeaderValue>())
40 {
41 headers.insert(header_name, header_value);
42 } else {
43 warn!("Failed to parse header: {} = {}", name, value);
44 }
45 }
46 Ok(headers)
47}
48
49fn serialize_statuscode<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
50where
51 S: Serializer,
52{
53 status.as_u16().serialize(serializer)
54}
55
56fn deserialize_statuscode<'de, D>(deserializer: D) -> Result<StatusCode, D::Error>
57where
58 D: Deserializer<'de>,
59{
60 let status_u16 = u16::deserialize(deserializer)?;
61 StatusCode::from_u16(status_u16).map_err(serde::de::Error::custom)
62}
63
64fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
65where
66 S: Serializer,
67{
68 url.to_string().serialize(serializer)
69}
70
71fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
72where
73 D: Deserializer<'de>,
74{
75 let s = String::deserialize(deserializer)?;
76 Url::parse(&s).map_err(serde::de::Error::custom)
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81struct CachedResponse {
82 #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
83 url: Url,
84 #[serde(
85 serialize_with = "serialize_statuscode",
86 deserialize_with = "deserialize_statuscode"
87 )]
88 status: StatusCode,
89 #[serde(
90 serialize_with = "serialize_headermap",
91 deserialize_with = "deserialize_headermap"
92 )]
93 headers: HeaderMap,
94 body: Vec<u8>,
95 #[serde(serialize_with = "serialize_url", deserialize_with = "deserialize_url")]
96 request_url: Url,
97}
98
99impl From<Response> for CachedResponse {
100 fn from(response: Response) -> Self {
101 CachedResponse {
102 url: response.url,
103 status: response.status,
104 headers: response.headers,
105 body: response.body.to_vec(),
106 request_url: response.request_url,
107 }
108 }
109}
110
111impl From<CachedResponse> for Response {
112 fn from(cached_response: CachedResponse) -> Self {
113 Response {
114 url: cached_response.url,
115 status: cached_response.status,
116 headers: cached_response.headers,
117 body: Bytes::from(cached_response.body),
118 request_url: cached_response.request_url,
119 meta: Default::default(),
120 }
121 }
122}
123
124#[derive(Default)]
126pub struct HttpCacheMiddlewareBuilder {
127 cache_dir: Option<PathBuf>,
128}
129
130impl HttpCacheMiddlewareBuilder {
131 pub fn cache_dir(mut self, path: PathBuf) -> Self {
133 self.cache_dir = Some(path);
134 self
135 }
136
137 pub fn build(self) -> Result<HttpCacheMiddleware, SpiderError> {
140 let cache_dir = if let Some(path) = self.cache_dir {
141 path
142 } else {
143 dirs::cache_dir()
144 .ok_or_else(|| {
145 SpiderError::ConfigurationError(
146 "Could not determine cache directory".to_string(),
147 )
148 })?
149 .join("spider-lib")
150 .join("http_cache")
151 };
152
153 utils::create_dir(&cache_dir)?;
154
155 let middleware = HttpCacheMiddleware { cache_dir };
156 info!(
157 "Initializing HttpCacheMiddleware with config: {:?}",
158 middleware
159 );
160
161 Ok(middleware)
162 }
163}
164
165#[derive(Debug)]
166pub struct HttpCacheMiddleware {
167 cache_dir: PathBuf,
168}
169
170impl HttpCacheMiddleware {
171 pub fn builder() -> HttpCacheMiddlewareBuilder {
173 HttpCacheMiddlewareBuilder::default()
174 }
175
176 fn get_cache_file_path(&self, fingerprint: &str) -> PathBuf {
177 self.cache_dir.join(format!("{}.bin", fingerprint))
178 }
179}
180
181#[async_trait]
182impl<C: Send + Sync> Middleware<C> for HttpCacheMiddleware {
183 fn name(&self) -> &str {
184 "HttpCacheMiddleware"
185 }
186
187 async fn process_request(
188 &mut self,
189 _client: &C,
190 request: Request,
191 ) -> Result<MiddlewareAction<Request>, SpiderError> {
192 let fingerprint = request.fingerprint();
193 let cache_file_path = self.get_cache_file_path(&fingerprint);
194
195 if fs::metadata(&cache_file_path).await.is_ok() {
196 debug!("Cache hit for request: {}", request.url);
197 match fs::read(&cache_file_path).await {
198 Ok(cached_bytes) => match bincode::deserialize::<CachedResponse>(&cached_bytes) {
199 Ok(cached_resp) => {
200 let mut response: Response = cached_resp.into();
201 response.meta = request.meta;
202 debug!("Returning cached response for {}", response.url);
203 return Ok(MiddlewareAction::ReturnResponse(response));
204 }
205 Err(e) => {
206 warn!(
207 "Failed to deserialize cached response from {}: {}. Deleting invalid cache file.",
208 cache_file_path.display(),
209 e
210 );
211 fs::remove_file(&cache_file_path).await.ok();
212 }
213 },
214 Err(e) => {
215 warn!(
216 "Failed to read cache file {}: {}. Deleting invalid cache file.",
217 cache_file_path.display(),
218 e
219 );
220 fs::remove_file(&cache_file_path).await.ok();
221 }
222 }
223 }
224
225 debug!("Cache miss for request: {}", request.url);
226 Ok(MiddlewareAction::Continue(request))
227 }
228
229 async fn process_response(
230 &mut self,
231 response: Response,
232 ) -> Result<MiddlewareAction<Response>, SpiderError> {
233 if response.status.is_success() {
235 let original_request_fingerprint = response.request_from_response().fingerprint();
236 let cache_file_path = self.get_cache_file_path(&original_request_fingerprint);
237
238 let cached_response: CachedResponse = response.clone().into();
239 match bincode::serialize(&cached_response) {
240 Ok(serialized_bytes) => {
241 fs::write(&cache_file_path, serialized_bytes)
242 .await
243 .map_err(SpiderError::IoError)?;
244 debug!("Cached response for {}", response.url);
245 }
246 Err(e) => {
247 warn!(
248 "Failed to serialize response for caching {}: {}",
249 response.url, e
250 );
251 }
252 }
253 }
254
255 Ok(MiddlewareAction::Continue(response))
256 }
257}