spider_lib/middlewares/
referer.rs1use async_trait::async_trait;
2use dashmap::DashMap;
3use reqwest::header::{HeaderValue, REFERER};
4use std::sync::Arc;
5use url::Url;
6
7use crate::error::SpiderError;
8use crate::middleware::{Middleware, MiddlewareAction};
9use crate::request::Request;
10use crate::response::Response;
11use tracing::{debug, info};
12
13#[derive(Debug, Clone)]
16pub struct RefererMiddleware {
17 pub same_origin_only: bool,
19 pub max_chain_length: usize,
21 pub include_fragment: bool,
23 referer_map: Arc<DashMap<String, Url>>,
25}
26
27impl Default for RefererMiddleware {
28 fn default() -> Self {
29 let middleware = RefererMiddleware {
30 same_origin_only: true,
31 max_chain_length: 1000,
32 include_fragment: false,
33 referer_map: Arc::new(DashMap::new()),
34 };
35 info!(
36 "Initializing RefererMiddleware with config: {:?}",
37 middleware
38 );
39 middleware
40 }
41}
42
43impl RefererMiddleware {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn same_origin_only(mut self, same_origin_only: bool) -> Self {
51 self.same_origin_only = same_origin_only;
52 self
53 }
54
55 pub fn max_chain_length(mut self, max_chain_length: usize) -> Self {
57 self.max_chain_length = max_chain_length;
58 self
59 }
60
61 pub fn include_fragment(mut self, include_fragment: bool) -> Self {
63 self.include_fragment = include_fragment;
64 self
65 }
66
67 fn get_referer_for_request(&self, request: &Request) -> Option<Url> {
69 if let Some(referer_value) = request.meta.get("referer")
70 && let Some(referer_str) = referer_value.value().as_str()
71 && let Ok(url) = Url::parse(referer_str)
72 {
73 if self.same_origin_only {
74 let request_origin = format!(
75 "{}://{}",
76 request.url.scheme(),
77 request.url.host_str().unwrap_or("")
78 );
79 let referer_origin = format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""));
80
81 if request_origin == referer_origin {
82 return Some(self.clean_url(&url));
83 }
84 } else {
85 return Some(self.clean_url(&url));
86 }
87 }
88
89 None
90 }
91
92 fn clean_url(&self, url: &Url) -> Url {
94 if !self.include_fragment && url.fragment().is_some() {
95 let mut cleaned = url.clone();
96 cleaned.set_fragment(None);
97 cleaned
98 } else {
99 url.clone()
100 }
101 }
102
103 fn request_key(&self, request: &Request) -> String {
105 format!("{}:{}", request.method, request.url)
107 }
108}
109
110#[async_trait]
111impl<C: Send + Sync> Middleware<C> for RefererMiddleware {
112 fn name(&self) -> &str {
113 "RefererMiddleware"
114 }
115
116 async fn process_request(
117 &mut self,
118 _client: &C,
119 mut request: Request,
120 ) -> Result<MiddlewareAction<Request>, SpiderError> {
121 let referer = self.get_referer_for_request(&request);
122 let referer = if let Some(ref_from_meta) = referer {
123 Some(ref_from_meta)
124 } else {
125 let request_key = self.request_key(&request);
126 self.referer_map
127 .get(&request_key)
128 .map(|entry| entry.value().clone())
129 };
130
131 let referer = if let Some(ref_url) = referer {
132 Some(ref_url)
133 } else if let Some(parent_id) = request.meta.get("parent_request_id") {
134 if let Some(parent_id_str) = parent_id.value().as_str() {
135 self.referer_map
136 .get(parent_id_str)
137 .map(|entry| entry.value().clone())
138 } else {
139 None
140 }
141 } else {
142 None
143 };
144
145 if let Some(referer) = referer {
146 match HeaderValue::from_str(referer.as_str()) {
147 Ok(header_value) => {
148 request.headers.insert(REFERER, header_value);
149 debug!(
150 "Set Referer header to: {} for request: {}",
151 referer, request.url
152 );
153 }
154 Err(e) => {
155 debug!("Failed to set Referer header: {}", e);
156 }
157 }
158 }
159
160 Ok(MiddlewareAction::Continue(request))
161 }
162
163 async fn process_response(
164 &mut self,
165 response: Response,
166 ) -> Result<MiddlewareAction<Response>, SpiderError> {
167 let response_url = response.url.clone();
168 let request = response.request_from_response();
169 let request_id = format!("req_{:x}", seahash::hash(request.url.as_str().as_bytes()));
170
171 let request_key = self.request_key(&request);
176 let cleaned_url = self.clean_url(&response_url);
177
178 if self.referer_map.len() < self.max_chain_length {
179 self.referer_map.insert(request_key, cleaned_url.clone());
180 self.referer_map.insert(request_id.clone(), cleaned_url);
181
182 debug!(
183 "Stored referer mapping for request {}: {}",
184 request.url, response_url
185 );
186 }
187
188 Ok(MiddlewareAction::Continue(response))
189 }
190}