Skip to main content

spider_util/
request.rs

1//! Data structures for representing HTTP requests in `spider-lib`.
2//!
3//! This module defines the `Request` struct, which is a central component
4//! for constructing and managing outgoing HTTP requests within the
5//! `spider-lib` framework. It encapsulates all necessary details of an
6//! HTTP request, including:
7//! - The target URL and HTTP method.
8//! - Request headers and an optional request body (supporting JSON, form data, or raw bytes).
9//! - Metadata for tracking retry attempts or other custom information.
10//!
11//! Additionally, the module provides methods for building requests,
12//! incrementing retry counters, and generating unique fingerprints
13//! for request deduplication and caching.
14
15use bytes::Bytes;
16use dashmap::DashMap;
17use http::header::HeaderMap;
18use reqwest::{Method, Url};
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::borrow::Cow;
22use std::collections::HashMap;
23use std::hash::Hasher;
24use std::str::FromStr;
25use std::sync::Arc;
26use twox_hash::XxHash64;
27
28use crate::error::SpiderError;
29
30#[derive(Debug, Clone)]
31pub enum Body {
32    Json(Value),
33    Form(DashMap<String, String>),
34    Bytes(Bytes),
35}
36
37// Custom serialization for Body enum
38impl Serialize for Body {
39    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
40    where
41        S: serde::Serializer,
42    {
43        use serde::ser::SerializeMap;
44        let mut map = serializer.serialize_map(Some(1))?;
45
46        match self {
47            Body::Json(value) => map.serialize_entry("Json", value)?,
48            Body::Form(dashmap) => {
49                let hmap: HashMap<String, String> = dashmap.clone().into_iter().collect();
50                map.serialize_entry("Form", &hmap)?
51            }
52            Body::Bytes(bytes) => map.serialize_entry("Bytes", bytes)?,
53        }
54
55        map.end()
56    }
57}
58
59impl<'de> Deserialize<'de> for Body {
60    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
61    where
62        D: serde::Deserializer<'de>,
63    {
64        use serde::de::{self, MapAccess, Visitor};
65        use std::fmt;
66
67        struct BodyVisitor;
68
69        impl<'de> Visitor<'de> for BodyVisitor {
70            type Value = Body;
71
72            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
73                formatter.write_str("a body object")
74            }
75
76            fn visit_map<V>(self, mut map: V) -> Result<Body, V::Error>
77            where
78                V: MapAccess<'de>,
79            {
80                let entry = map.next_entry::<String, Value>()?;
81                let (key, value) = match entry {
82                    Some((k, v)) => (k, v),
83                    None => return Err(de::Error::custom("Expected a body variant")),
84                };
85
86                match key.as_str() {
87                    "Json" => Ok(Body::Json(value)),
88                    "Form" => {
89                        let form_data: HashMap<String, String> =
90                            serde_json::from_value(value).map_err(de::Error::custom)?;
91                        let dashmap = DashMap::new();
92                        for (k, v) in form_data {
93                            dashmap.insert(k, v);
94                        }
95                        Ok(Body::Form(dashmap))
96                    }
97                    "Bytes" => {
98                        let bytes: Bytes =
99                            serde_json::from_value(value).map_err(de::Error::custom)?;
100                        Ok(Body::Bytes(bytes))
101                    }
102                    _ => Err(de::Error::custom(format!("Unknown body variant: {}", key))),
103                }
104            }
105        }
106
107        deserializer.deserialize_map(BodyVisitor)
108    }
109}
110
111#[derive(Debug, Clone)]
112pub struct Request {
113    pub url: Url,
114    pub method: Method,
115    pub headers: HeaderMap,
116    pub body: Option<Body>,
117    pub meta: Arc<DashMap<Cow<'static, str>, Value>>,
118}
119
120// Custom serialization for Request struct
121impl Serialize for Request {
122    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
123    where
124        S: serde::Serializer,
125    {
126        use serde::ser::SerializeStruct;
127        // Convert HeaderMap to a serializable format
128        let headers_vec: Vec<(String, String)> = self
129            .headers
130            .iter()
131            .filter_map(|(name, value)| {
132                value
133                    .to_str()
134                    .ok()
135                    .map(|val_str| (name.as_str().to_string(), val_str.to_string()))
136            })
137            .collect();
138
139        let mut s = serializer.serialize_struct("Request", 5)?;
140        s.serialize_field("url", &self.url.as_str())?;
141        s.serialize_field("method", &self.method.as_str())?;
142        s.serialize_field("headers", &headers_vec)?;
143        s.serialize_field("body", &self.body)?;
144        s.end()
145    }
146}
147
148impl<'de> Deserialize<'de> for Request {
149    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
150    where
151        D: serde::Deserializer<'de>,
152    {
153        use serde::de::{self, MapAccess, Visitor};
154        use std::fmt;
155
156        #[derive(Deserialize)]
157        #[serde(field_identifier, rename_all = "lowercase")]
158        enum Field {
159            Url,
160            Method,
161            Headers,
162            Body,
163        }
164
165        struct RequestVisitor;
166
167        impl<'de> Visitor<'de> for RequestVisitor {
168            type Value = Request;
169
170            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
171                formatter.write_str("struct Request")
172            }
173
174            fn visit_map<V>(self, mut map: V) -> Result<Request, V::Error>
175            where
176                V: MapAccess<'de>,
177            {
178                let mut url = None;
179                let mut method = None;
180                let mut headers = None;
181                let mut body = None;
182
183                while let Some(key) = map.next_key()? {
184                    match key {
185                        Field::Url => {
186                            if url.is_some() {
187                                return Err(de::Error::duplicate_field("url"));
188                            }
189                            let url_str: String = map.next_value()?;
190                            let parsed_url = Url::parse(&url_str).map_err(de::Error::custom)?;
191                            url = Some(parsed_url);
192                        }
193                        Field::Method => {
194                            if method.is_some() {
195                                return Err(de::Error::duplicate_field("method"));
196                            }
197                            let method_str: String = map.next_value()?;
198                            let parsed_method =
199                                Method::from_str(&method_str).map_err(de::Error::custom)?;
200                            method = Some(parsed_method);
201                        }
202                        Field::Headers => {
203                            if headers.is_some() {
204                                return Err(de::Error::duplicate_field("headers"));
205                            }
206                            // Deserialize headers vector and convert back to HeaderMap
207                            let headers_vec: Vec<(String, String)> = map.next_value()?;
208                            let mut header_map = HeaderMap::new();
209                            for (name, value) in headers_vec {
210                                if let Ok(header_name) =
211                                    http::header::HeaderName::from_bytes(name.as_bytes())
212                                    && let Ok(header_value) =
213                                        http::header::HeaderValue::from_str(&value)
214                                {
215                                    header_map.insert(header_name, header_value);
216                                }
217                            }
218                            headers = Some(header_map);
219                        }
220                        Field::Body => {
221                            if body.is_some() {
222                                return Err(de::Error::duplicate_field("body"));
223                            }
224                            body = Some(map.next_value()?);
225                        }
226                    }
227                }
228
229                let url = url.ok_or_else(|| de::Error::missing_field("url"))?;
230                let method = method.ok_or_else(|| de::Error::missing_field("method"))?;
231                let headers = headers.ok_or_else(|| de::Error::missing_field("headers"))?;
232                let body = body; // Optional field
233
234                Ok(Request {
235                    url,
236                    method,
237                    headers,
238                    body,
239                    meta: Arc::new(DashMap::new()), // Initialize empty meta map
240                })
241            }
242        }
243
244        const FIELDS: &[&str] = &["url", "method", "headers", "body"];
245        deserializer.deserialize_struct("Request", FIELDS, RequestVisitor)
246    }
247}
248
249impl Default for Request {
250    fn default() -> Self {
251        Self {
252            url: Url::parse("http://default.invalid").unwrap(),
253            method: Method::GET,
254            headers: HeaderMap::new(),
255            body: None,
256            meta: Arc::new(DashMap::new()),
257        }
258    }
259}
260
261impl Request {
262    /// Creates a new `Request` with the given URL.
263    pub fn new(url: Url) -> Self {
264        Request {
265            url,
266            method: Method::GET,
267            headers: HeaderMap::new(),
268            body: None,
269            meta: Arc::new(DashMap::new()),
270        }
271    }
272
273    /// Sets the HTTP method for the request.
274    pub fn with_method(mut self, method: Method) -> Self {
275        self.method = method;
276        self
277    }
278
279    /// Adds a header to the request.
280    pub fn with_header(mut self, name: &str, value: &str) -> Result<Self, SpiderError> {
281        let header_name =
282            reqwest::header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
283                SpiderError::HeaderValueError(format!("Invalid header name '{}': {}", name, e))
284            })?;
285        let header_value = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
286            SpiderError::HeaderValueError(format!("Invalid header value '{}': {}", value, e))
287        })?;
288
289        self.headers.insert(header_name, header_value);
290        Ok(self)
291    }
292
293    /// Sets the body of the request and defaults the method to POST.
294    pub fn with_body(mut self, body: Body) -> Self {
295        self.body = Some(body);
296        self.with_method(Method::POST)
297    }
298
299    /// Sets the body of the request to a JSON value.
300    pub fn with_json(self, json: Value) -> Self {
301        self.with_body(Body::Json(json))
302    }
303
304    /// Sets the body of the request to a form.
305    pub fn with_form(self, form: DashMap<String, String>) -> Self {
306        self.with_body(Body::Form(form))
307    }
308
309    /// Sets the body of the request to a byte slice.
310    pub fn with_bytes(self, bytes: Bytes) -> Self {
311        self.with_body(Body::Bytes(bytes))
312    }
313
314    /// Adds a value to the request's metadata.
315    pub fn with_meta(self, key: &str, value: Value) -> Self {
316        self.meta.insert(Cow::Owned(key.to_owned()), value);
317        self
318    }
319
320    const RETRY_ATTEMPTS_KEY: &str = "retry_attempts";
321
322    /// Gets the number of times the request has been retried.
323    pub fn get_retry_attempts(&self) -> u32 {
324        self.meta
325            .get(Self::RETRY_ATTEMPTS_KEY)
326            .and_then(|v| v.value().as_u64())
327            .unwrap_or(0) as u32
328    }
329
330    /// Increments the retry count for the request.
331    pub fn increment_retry_attempts(&mut self) {
332        let current_attempts = self.get_retry_attempts();
333        self.meta.insert(
334            Cow::Borrowed(Self::RETRY_ATTEMPTS_KEY),
335            Value::from(current_attempts + 1),
336        );
337    }
338
339    /// Generates a unique fingerprint for the request based on its URL, method, and body.
340    pub fn fingerprint(&self) -> String {
341        let mut hasher = XxHash64::default();
342        hasher.write(self.url.as_str().as_bytes());
343        hasher.write(self.method.as_str().as_bytes());
344
345        if let Some(ref body) = self.body {
346            match body {
347                Body::Json(json_val) => {
348                    if let Ok(serialized) = serde_json::to_string(json_val) {
349                        hasher.write(serialized.as_bytes());
350                    }
351                }
352                Body::Form(form_val) => {
353                    let mut form_string = String::new();
354                    for r in form_val.iter() {
355                        form_string.push_str(r.key());
356                        form_string.push_str(r.value());
357                    }
358                    hasher.write(form_string.as_bytes());
359                }
360                Body::Bytes(bytes_val) => {
361                    hasher.write(bytes_val);
362                }
363            }
364        }
365        format!("{:x}", hasher.finish())
366    }
367}