// spider_middleware/referer.rs

use std::sync::Arc;

use async_trait::async_trait;
use dashmap::DashMap;
use log::{debug, info};
use reqwest::header::{HeaderValue, REFERER};
use spider_util::error::SpiderError;
use spider_util::request::Request;
use spider_util::response::Response;
use url::Url;

use crate::middleware::{Middleware, MiddlewareAction};
14
15#[derive(Debug, Clone)]
17pub struct RefererMiddleware {
18 pub same_origin_only: bool,
20 pub max_chain_length: usize,
22 pub include_fragment: bool,
24 referer_map: Arc<DashMap<String, Url>>,
26}
27
28impl Default for RefererMiddleware {
29 fn default() -> Self {
30 let middleware = RefererMiddleware {
31 same_origin_only: true,
32 max_chain_length: 1000,
33 include_fragment: false,
34 referer_map: Arc::new(DashMap::new()),
35 };
36 info!(
37 "Initializing RefererMiddleware with config: {:?}",
38 middleware
39 );
40 middleware
41 }
42}
43
44impl RefererMiddleware {
45 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn same_origin_only(mut self, same_origin_only: bool) -> Self {
52 self.same_origin_only = same_origin_only;
53 self
54 }
55
56 pub fn max_chain_length(mut self, max_chain_length: usize) -> Self {
58 self.max_chain_length = max_chain_length;
59 self
60 }
61
62 pub fn include_fragment(mut self, include_fragment: bool) -> Self {
64 self.include_fragment = include_fragment;
65 self
66 }
67
68 fn get_referer_for_request(&self, request: &Request) -> Option<Url> {
70 if let Some(referer_value) = request.get_meta_ref("referer")
71 && let Some(referer_str) = referer_value.value().as_str()
72 && let Ok(url) = Url::parse(referer_str)
73 {
74 if self.same_origin_only {
75 let request_origin = format!(
76 "{}://{}",
77 request.url.scheme(),
78 request.url.host_str().unwrap_or("")
79 );
80 let referer_origin = format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""));
81
82 if request_origin == referer_origin {
83 return Some(self.clean_url(&url));
84 }
85 } else {
86 return Some(self.clean_url(&url));
87 }
88 }
89
90 None
91 }
92
93 fn clean_url(&self, url: &Url) -> Url {
95 if !self.include_fragment && url.fragment().is_some() {
96 let mut cleaned = url.clone();
97 cleaned.set_fragment(None);
98 cleaned
99 } else {
100 url.clone()
101 }
102 }
103
104 fn request_key(&self, request: &Request) -> String {
106 format!("{}:{}", request.method, request.url)
108 }
109}
110
111#[async_trait]
112impl<C: Send + Sync> Middleware<C> for RefererMiddleware {
113 fn name(&self) -> &str {
114 "RefererMiddleware"
115 }
116
117 async fn process_request(
118 &self,
119 _client: &C,
120 mut request: Request,
121 ) -> Result<MiddlewareAction<Request>, SpiderError> {
122 let referer = self.get_referer_for_request(&request);
123 let referer = if let Some(ref_from_meta) = referer {
124 Some(ref_from_meta)
125 } else {
126 let request_key = self.request_key(&request);
127 self.referer_map
128 .get(&request_key)
129 .map(|entry| entry.value().clone())
130 };
131
132 let referer = if let Some(ref_url) = referer {
133 Some(ref_url)
134 } else if let Some(parent_id) = request.get_meta_ref("parent_request_id") {
135 if let Some(parent_id_str) = parent_id.value().as_str() {
136 self.referer_map
137 .get(parent_id_str)
138 .map(|entry| entry.value().clone())
139 } else {
140 None
141 }
142 } else {
143 None
144 };
145
146 if let Some(referer) = referer {
147 match HeaderValue::from_str(referer.as_str()) {
148 Ok(header_value) => {
149 request.headers.insert(REFERER, header_value);
150 debug!(
151 "Set Referer header to: {} for request: {}",
152 referer, request.url
153 );
154 }
155 Err(e) => {
156 debug!("Failed to set Referer header: {}", e);
157 }
158 }
159 }
160
161 Ok(MiddlewareAction::Continue(request))
162 }
163
164 async fn process_response(
165 &self,
166 response: Response,
167 ) -> Result<MiddlewareAction<Response>, SpiderError> {
168 let response_url = response.url.clone();
169 let request = response.request_from_response();
170 let request_id = format!("req_{:x}", seahash::hash(request.url.as_str().as_bytes()));
171
172 let request_key = self.request_key(&request);
177 let cleaned_url = self.clean_url(&response_url);
178
179 if self.referer_map.len() < self.max_chain_length {
180 self.referer_map.insert(request_key, cleaned_url.clone());
181 self.referer_map.insert(request_id.clone(), cleaned_url);
182
183 debug!(
184 "Stored referer mapping for request {}: {}",
185 request.url, response_url
186 );
187 }
188
189 Ok(MiddlewareAction::Continue(response))
190 }
191}