Skip to main content

spider_lib/
request.rs

1use crate::error::SpiderError;
2use bytes::Bytes;
3use dashmap::DashMap;
4use hex;
5use http::header::HeaderMap;
6use reqwest::{Method, Url};
7use serde::{Deserialize, Serialize};
8use serde_json;
9use serde_with::{DisplayFromStr, serde_as};
10use sha2::{Digest, Sha256};
11use std::borrow::Cow;
12
/// Payload that can be attached to a [`Request`].
///
/// The variant also decides how [`Request::fingerprint`] feeds the payload
/// into the hash.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Body {
    /// An arbitrary JSON value.
    Json(serde_json::Value),
    /// Form fields stored as string key/value pairs.
    Form(DashMap<String, String>),
    /// A raw byte buffer.
    Bytes(Bytes),
}
19
/// Description of a single HTTP request: target URL, method, headers,
/// optional body and free-form metadata.
///
/// The type is (de)serializable; `meta` is excluded from the serialized form.
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Request {
    /// Target URL of the request.
    pub url: Url,
    /// HTTP method, (de)serialized through its string representation.
    #[serde_as(as = "DisplayFromStr")]
    pub method: Method,
    /// Request headers, (de)serialized by the helpers in `header_serde`.
    #[serde(
        serialize_with = "header_serde::serialize_headers",
        deserialize_with = "header_serde::deserialize_headers"
    )]
    pub headers: HeaderMap,
    /// Optional request payload.
    pub body: Option<Body>,
    /// Arbitrary per-request metadata (the retry counter is stored here).
    /// Skipped by serde: never written out, and recreated empty when a
    /// `Request` is deserialized.
    #[serde(skip)]
    pub meta: DashMap<Cow<'static, str>, serde_json::Value>,
}
35
36impl Default for Request {
37    fn default() -> Self {
38        Self {
39            url: Url::parse("http://default.invalid").unwrap(),
40            method: Method::GET,
41            headers: HeaderMap::new(),
42            body: None,
43            meta: DashMap::new(),
44        }
45    }
46}
47
48impl Request {
49    /// Creates a new `Request` with the given URL.
50    pub fn new(url: Url) -> Self {
51        Request {
52            url,
53            method: Method::GET,
54            headers: HeaderMap::new(),
55            body: None,
56            meta: DashMap::new(),
57        }
58    }
59
60    /// Sets the HTTP method for the request.
61    pub fn with_method(mut self, method: Method) -> Self {
62        self.method = method;
63        self
64    }
65
66    /// Adds a header to the request.
67    pub fn with_header(mut self, name: &str, value: &str) -> Result<Self, SpiderError> {
68        let header_name =
69            reqwest::header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
70                SpiderError::HeaderValueError(format!("Invalid header name '{}': {}", name, e))
71            })?;
72        let header_value = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
73            SpiderError::HeaderValueError(format!("Invalid header value '{}': {}", value, e))
74        })?;
75
76        self.headers.insert(header_name, header_value);
77        Ok(self)
78    }
79
80    /// Sets the body of the request and defaults the method to POST.
81    pub fn with_body(mut self, body: Body) -> Self {
82        self.body = Some(body);
83        self.with_method(Method::POST)
84    }
85
86    /// Sets the body of the request to a JSON value.
87    pub fn with_json(self, json: serde_json::Value) -> Self {
88        self.with_body(Body::Json(json))
89    }
90
91    /// Sets the body of the request to a form.
92    pub fn with_form(self, form: DashMap<String, String>) -> Self {
93        self.with_body(Body::Form(form))
94    }
95
96    /// Sets the body of the request to a byte slice.
97    pub fn with_bytes(self, bytes: Bytes) -> Self {
98        self.with_body(Body::Bytes(bytes))
99    }
100
101    /// Adds a value to the request's metadata.
102    pub fn with_meta(self, key: &str, value: serde_json::Value) -> Self {
103        self.meta.insert(Cow::Owned(key.to_owned()), value);
104        self
105    }
106
107    const RETRY_ATTEMPTS_KEY: &str = "retry_attempts";
108
109    /// Gets the number of times the request has been retried.
110    pub fn get_retry_attempts(&self) -> u32 {
111        self.meta
112            .get(Self::RETRY_ATTEMPTS_KEY)
113            .and_then(|v| v.value().as_u64())
114            .unwrap_or(0) as u32
115    }
116
117    /// Increments the retry count for the request.
118    pub fn increment_retry_attempts(&mut self) {
119        let current_attempts = self.get_retry_attempts();
120        self.meta.insert(
121            Cow::Borrowed(Self::RETRY_ATTEMPTS_KEY),
122            serde_json::to_value(current_attempts + 1).unwrap(),
123        );
124    }
125
126    /// Generates a unique fingerprint for the request based on its URL, method, and body.
127    pub fn fingerprint(&self) -> String {
128        let mut hasher = Sha256::new();
129        hasher.update(self.url.as_str().as_bytes());
130        hasher.update(self.method.as_str().as_bytes());
131
132        if let Some(ref body) = self.body {
133            match body {
134                Body::Json(json_val) => {
135                    if let Ok(serialized) = serde_json::to_string(json_val) {
136                        hasher.update(serialized.as_bytes());
137                    }
138                }
139                Body::Form(form_val) => {
140                    let mut form_string = String::new();
141                    for r in form_val.iter() {
142                        form_string.push_str(r.key());
143                        form_string.push_str(r.value());
144                    }
145                    hasher.update(form_string.as_bytes());
146                }
147                Body::Bytes(bytes_val) => {
148                    hasher.update(bytes_val);
149                }
150            }
151        }
152        hex::encode(hasher.finalize())
153    }
154}
155
156mod header_serde {
157    use super::*;
158    use reqwest::header::{HeaderName, HeaderValue};
159    use serde::{Deserializer, Serializer};
160    use std::str::FromStr;
161
162    pub fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
163    where
164        S: Serializer,
165    {
166        let map: Vec<(&str, &str)> = headers
167            .iter()
168            .filter_map(|(name, value)| {
169                value
170                    .to_str()
171                    .ok()
172                    .map(|value_str| (name.as_str(), value_str))
173            })
174            .collect();
175        map.serialize(serializer)
176    }
177
178    pub fn deserialize_headers<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
179    where
180        D: Deserializer<'de>,
181    {
182        let vec = Vec::<(&str, &str)>::deserialize(deserializer)?;
183        let mut headers = HeaderMap::new();
184        for (name, value) in vec {
185            let header_name = HeaderName::from_str(name).map_err(serde::de::Error::custom)?;
186            let header_value = HeaderValue::from_str(value).map_err(serde::de::Error::custom)?;
187            headers.insert(header_name, header_value);
188        }
189        Ok(headers)
190    }
191}