spider_lib/middlewares/
referer.rs1use async_trait::async_trait;
15use dashmap::DashMap;
16use reqwest::header::{HeaderValue, REFERER};
17use std::sync::Arc;
18use url::Url;
19
20use crate::error::SpiderError;
21use crate::middleware::{Middleware, MiddlewareAction};
22use crate::request::Request;
23use crate::response::Response;
24use tracing::{debug, info};
25
26#[derive(Debug, Clone)]
29pub struct RefererMiddleware {
30 pub same_origin_only: bool,
32 pub max_chain_length: usize,
34 pub include_fragment: bool,
36 referer_map: Arc<DashMap<String, Url>>,
38}
39
40impl Default for RefererMiddleware {
41 fn default() -> Self {
42 let middleware = RefererMiddleware {
43 same_origin_only: true,
44 max_chain_length: 1000,
45 include_fragment: false,
46 referer_map: Arc::new(DashMap::new()),
47 };
48 info!(
49 "Initializing RefererMiddleware with config: {:?}",
50 middleware
51 );
52 middleware
53 }
54}
55
56impl RefererMiddleware {
57 pub fn new() -> Self {
59 Self::default()
60 }
61
62 pub fn same_origin_only(mut self, same_origin_only: bool) -> Self {
64 self.same_origin_only = same_origin_only;
65 self
66 }
67
68 pub fn max_chain_length(mut self, max_chain_length: usize) -> Self {
70 self.max_chain_length = max_chain_length;
71 self
72 }
73
74 pub fn include_fragment(mut self, include_fragment: bool) -> Self {
76 self.include_fragment = include_fragment;
77 self
78 }
79
80 fn get_referer_for_request(&self, request: &Request) -> Option<Url> {
82 if let Some(referer_value) = request.meta.get("referer")
83 && let Some(referer_str) = referer_value.value().as_str()
84 && let Ok(url) = Url::parse(referer_str)
85 {
86 if self.same_origin_only {
87 let request_origin = format!(
88 "{}://{}",
89 request.url.scheme(),
90 request.url.host_str().unwrap_or("")
91 );
92 let referer_origin = format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""));
93
94 if request_origin == referer_origin {
95 return Some(self.clean_url(&url));
96 }
97 } else {
98 return Some(self.clean_url(&url));
99 }
100 }
101
102 None
103 }
104
105 fn clean_url(&self, url: &Url) -> Url {
107 if !self.include_fragment && url.fragment().is_some() {
108 let mut cleaned = url.clone();
109 cleaned.set_fragment(None);
110 cleaned
111 } else {
112 url.clone()
113 }
114 }
115
116 fn request_key(&self, request: &Request) -> String {
118 format!("{}:{}", request.method, request.url)
120 }
121}
122
123#[async_trait]
124impl<C: Send + Sync> Middleware<C> for RefererMiddleware {
125 fn name(&self) -> &str {
126 "RefererMiddleware"
127 }
128
129 async fn process_request(
130 &mut self,
131 _client: &C,
132 mut request: Request,
133 ) -> Result<MiddlewareAction<Request>, SpiderError> {
134 let referer = self.get_referer_for_request(&request);
135 let referer = if let Some(ref_from_meta) = referer {
136 Some(ref_from_meta)
137 } else {
138 let request_key = self.request_key(&request);
139 self.referer_map
140 .get(&request_key)
141 .map(|entry| entry.value().clone())
142 };
143
144 let referer = if let Some(ref_url) = referer {
145 Some(ref_url)
146 } else if let Some(parent_id) = request.meta.get("parent_request_id") {
147 if let Some(parent_id_str) = parent_id.value().as_str() {
148 self.referer_map
149 .get(parent_id_str)
150 .map(|entry| entry.value().clone())
151 } else {
152 None
153 }
154 } else {
155 None
156 };
157
158 if let Some(referer) = referer {
159 match HeaderValue::from_str(referer.as_str()) {
160 Ok(header_value) => {
161 request.headers.insert(REFERER, header_value);
162 debug!(
163 "Set Referer header to: {} for request: {}",
164 referer, request.url
165 );
166 }
167 Err(e) => {
168 debug!("Failed to set Referer header: {}", e);
169 }
170 }
171 }
172
173 Ok(MiddlewareAction::Continue(request))
174 }
175
176 async fn process_response(
177 &mut self,
178 response: Response,
179 ) -> Result<MiddlewareAction<Response>, SpiderError> {
180 let response_url = response.url.clone();
181 let request = response.request_from_response();
182 let request_id = format!("req_{:x}", seahash::hash(request.url.as_str().as_bytes()));
183
184 let request_key = self.request_key(&request);
189 let cleaned_url = self.clean_url(&response_url);
190
191 if self.referer_map.len() < self.max_chain_length {
192 self.referer_map.insert(request_key, cleaned_url.clone());
193 self.referer_map.insert(request_id.clone(), cleaned_url);
194
195 debug!(
196 "Stored referer mapping for request {}: {}",
197 request.url, response_url
198 );
199 }
200
201 Ok(MiddlewareAction::Continue(response))
202 }
203}