1use crate::error::SpiderError;
2use bytes::Bytes;
3use dashmap::DashMap;
4use hex;
5use http::header::HeaderMap;
6use reqwest::{Method, Url};
7use serde::{Deserialize, Serialize};
8use serde_json;
9use serde_with::{DisplayFromStr, serde_as};
10use sha2::{Digest, Sha256};
11use std::borrow::Cow;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub enum Body {
15 Json(serde_json::Value),
16 Form(DashMap<String, String>),
17 Bytes(Bytes),
18}
19
20#[serde_as]
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Request {
23 pub url: Url,
24 #[serde_as(as = "DisplayFromStr")]
25 pub method: Method,
26 #[serde(
27 serialize_with = "header_serde::serialize_headers",
28 deserialize_with = "header_serde::deserialize_headers"
29 )]
30 pub headers: HeaderMap,
31 pub body: Option<Body>,
32 #[serde(skip)]
33 pub meta: DashMap<Cow<'static, str>, serde_json::Value>,
34}
35
36impl Default for Request {
37 fn default() -> Self {
38 Self {
39 url: Url::parse("http://default.invalid").unwrap(),
40 method: Method::GET,
41 headers: HeaderMap::new(),
42 body: None,
43 meta: DashMap::new(),
44 }
45 }
46}
47
48impl Request {
49 pub fn new(url: Url) -> Self {
51 Request {
52 url,
53 method: Method::GET,
54 headers: HeaderMap::new(),
55 body: None,
56 meta: DashMap::new(),
57 }
58 }
59
60 pub fn with_method(mut self, method: Method) -> Self {
62 self.method = method;
63 self
64 }
65
66 pub fn with_header(mut self, name: &str, value: &str) -> Result<Self, SpiderError> {
68 let header_name =
69 reqwest::header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
70 SpiderError::HeaderValueError(format!("Invalid header name '{}': {}", name, e))
71 })?;
72 let header_value = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
73 SpiderError::HeaderValueError(format!("Invalid header value '{}': {}", value, e))
74 })?;
75
76 self.headers.insert(header_name, header_value);
77 Ok(self)
78 }
79
80 pub fn with_body(mut self, body: Body) -> Self {
82 self.body = Some(body);
83 self.with_method(Method::POST)
84 }
85
86 pub fn with_json(self, json: serde_json::Value) -> Self {
88 self.with_body(Body::Json(json))
89 }
90
91 pub fn with_form(self, form: DashMap<String, String>) -> Self {
93 self.with_body(Body::Form(form))
94 }
95
96 pub fn with_bytes(self, bytes: Bytes) -> Self {
98 self.with_body(Body::Bytes(bytes))
99 }
100
101 pub fn with_meta(self, key: &str, value: serde_json::Value) -> Self {
103 self.meta.insert(Cow::Owned(key.to_owned()), value);
104 self
105 }
106
107 const RETRY_ATTEMPTS_KEY: &str = "retry_attempts";
108
109 pub fn get_retry_attempts(&self) -> u32 {
111 self.meta
112 .get(Self::RETRY_ATTEMPTS_KEY)
113 .and_then(|v| v.value().as_u64())
114 .unwrap_or(0) as u32
115 }
116
117 pub fn increment_retry_attempts(&mut self) {
119 let current_attempts = self.get_retry_attempts();
120 self.meta.insert(
121 Cow::Borrowed(Self::RETRY_ATTEMPTS_KEY),
122 serde_json::to_value(current_attempts + 1).unwrap(),
123 );
124 }
125
126 pub fn fingerprint(&self) -> String {
128 let mut hasher = Sha256::new();
129 hasher.update(self.url.as_str().as_bytes());
130 hasher.update(self.method.as_str().as_bytes());
131
132 if let Some(ref body) = self.body {
133 match body {
134 Body::Json(json_val) => {
135 if let Ok(serialized) = serde_json::to_string(json_val) {
136 hasher.update(serialized.as_bytes());
137 }
138 }
139 Body::Form(form_val) => {
140 let mut form_string = String::new();
141 for r in form_val.iter() {
142 form_string.push_str(r.key());
143 form_string.push_str(r.value());
144 }
145 hasher.update(form_string.as_bytes());
146 }
147 Body::Bytes(bytes_val) => {
148 hasher.update(bytes_val);
149 }
150 }
151 }
152 hex::encode(hasher.finalize())
153 }
154}
155
156mod header_serde {
157 use super::*;
158 use reqwest::header::{HeaderName, HeaderValue};
159 use serde::{Deserializer, Serializer};
160 use std::str::FromStr;
161
162 pub fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
163 where
164 S: Serializer,
165 {
166 let map: Vec<(&str, &str)> = headers
167 .iter()
168 .filter_map(|(name, value)| {
169 value
170 .to_str()
171 .ok()
172 .map(|value_str| (name.as_str(), value_str))
173 })
174 .collect();
175 map.serialize(serializer)
176 }
177
178 pub fn deserialize_headers<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
179 where
180 D: Deserializer<'de>,
181 {
182 let vec = Vec::<(&str, &str)>::deserialize(deserializer)?;
183 let mut headers = HeaderMap::new();
184 for (name, value) in vec {
185 let header_name = HeaderName::from_str(name).map_err(serde::de::Error::custom)?;
186 let header_value = HeaderValue::from_str(value).map_err(serde::de::Error::custom)?;
187 headers.insert(header_name, header_value);
188 }
189 Ok(headers)
190 }
191}