1use async_trait::async_trait;
8use dashmap::DashMap;
9use moka::sync::Cache;
10use reqwest::header::{HeaderValue, USER_AGENT};
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::fmt::Debug;
13use std::fs::File;
14use std::io::{BufRead, BufReader};
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20use ua_generator::ua::*;
21
22use rand::seq::SliceRandom;
23
24use spider_util::error::SpiderError;
25use crate::middleware::{Middleware, MiddlewareAction};
26use spider_util::request::Request;
27
/// Policy used to pick a User-Agent profile from a pool for each request.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum UserAgentRotationStrategy {
    /// Pick a random profile from the pool on every request (the default).
    #[default]
    Random,
    /// Walk the pool in order, wrapping around at the end. The position is a
    /// single counter shared by every domain.
    Sequential,
    /// Pick a random profile the first time a domain is seen and keep using
    /// it for that domain for the lifetime of the middleware.
    Sticky,
    /// Like `Sticky`, but the remembered profile is kept in a time-limited
    /// session cache and a new one is drawn once the entry has expired.
    StickySession,
}
41
/// Names one of the static User-Agent lists provided by the `ua_generator`
/// crate.
///
/// Each variant is mapped to a static list by
/// `UserAgentMiddleware::load_builtin_user_agents`; `Random` selects the list
/// returned by `all_static_agents()`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BuiltinUserAgentList {
    // Chrome family.
    Chrome,
    ChromeLinux,
    ChromeMac,
    ChromeMobile,
    ChromeTablet,
    ChromeWindows,
    // Firefox family.
    Firefox,
    FirefoxLinux,
    FirefoxMac,
    FirefoxMobile,
    FirefoxTablet,
    FirefoxWindows,
    // Safari family.
    Safari,
    SafariMac,
    SafariMobile,
    SafariTablet,
    SafariWindows,
    /// The list returned by `all_static_agents()`.
    Random,
}
82
83#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
85#[serde(untagged)]
86pub enum UserAgentSource {
87 List(Vec<String>),
89 File(PathBuf),
91 Builtin(BuiltinUserAgentList),
93 None,
95}
96
97impl Default for UserAgentSource {
98 fn default() -> Self {
99 UserAgentSource::Builtin(BuiltinUserAgentList::Random)
100 }
101}
102
103fn serialize_arc_string<S>(x: &Arc<String>, s: S) -> Result<S::Ok, S::Error>
105where
106 S: Serializer,
107{
108 s.serialize_str(x.as_str())
109}
110
111fn deserialize_arc_string<'de, D>(deserializer: D) -> Result<Arc<String>, D::Error>
113where
114 D: Deserializer<'de>,
115{
116 let s = String::deserialize(deserializer)?;
117 Ok(Arc::new(s))
118}
119
/// A User-Agent string together with any extra headers that should be sent
/// alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserAgentProfile {
    /// The User-Agent header value. Stored behind an `Arc` and (de)serialized
    /// as a plain string through the two helper functions named below.
    #[serde(serialize_with = "serialize_arc_string", deserialize_with = "deserialize_arc_string")]
    pub user_agent: Arc<String>,
    /// Additional header name/value pairs applied to the request together
    /// with the User-Agent. Defaults to an empty map when absent from the
    /// serialized form.
    #[serde(default)]
    pub headers: DashMap<String, String>,
}
130
131impl From<String> for UserAgentProfile {
132 fn from(user_agent: String) -> Self {
133 UserAgentProfile {
134 user_agent: Arc::new(user_agent),
135 headers: DashMap::new(),
136 }
137 }
138}
139
140impl From<&str> for UserAgentProfile {
141 fn from(user_agent: &str) -> Self {
142 UserAgentProfile {
143 user_agent: Arc::new(user_agent.to_string()),
144 headers: DashMap::new(),
145 }
146 }
147}
148
/// Collects the configuration for a `UserAgentMiddleware` and turns it into a
/// ready-to-use middleware via `build`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UserAgentMiddlewareBuilder {
    /// Source of the default User-Agent pool.
    source: UserAgentSource,
    /// Rotation strategy used when a domain has no override.
    strategy: UserAgentRotationStrategy,
    /// User-Agent used when the selected pool is empty.
    fallback_user_agent: Option<String>,
    /// Per-domain pool sources, keyed by domain.
    per_domain_source: DashMap<String, UserAgentSource>,
    /// Per-domain rotation strategies, keyed by domain.
    per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
    /// Lifetime of `StickySession` entries; `build` uses five minutes when
    /// this is `None`.
    session_duration: Option<Duration>,
}
159
160impl UserAgentMiddlewareBuilder {
161 pub fn source(mut self, source: UserAgentSource) -> Self {
163 self.source = source;
164 self
165 }
166
167 pub fn strategy(mut self, strategy: UserAgentRotationStrategy) -> Self {
169 self.strategy = strategy;
170 self
171 }
172
173 pub fn session_duration(mut self, duration: Duration) -> Self {
175 self.session_duration = Some(duration);
176 self
177 }
178
179 pub fn fallback_user_agent(mut self, fallback_user_agent: String) -> Self {
181 self.fallback_user_agent = Some(fallback_user_agent);
182 self
183 }
184
185 pub fn per_domain_source(self, domain: String, source: UserAgentSource) -> Self {
187 self.per_domain_source.insert(domain, source);
188 self
189 }
190
191 pub fn per_domain_strategy(self, domain: String, strategy: UserAgentRotationStrategy) -> Self {
193 self.per_domain_strategy.insert(domain, strategy);
194 self
195 }
196
197 pub fn build(self) -> Result<UserAgentMiddleware, SpiderError> {
200 let default_pool = Arc::new(UserAgentMiddleware::load_user_agents(&self.source)?);
201
202 let domain_cache = Cache::builder()
203 .time_to_live(Duration::from_secs(30 * 60)) .build();
205
206 for entry in self.per_domain_source.iter() {
207 let domain = entry.key().clone();
208 let source = entry.value().clone();
209 let pool = Arc::new(UserAgentMiddleware::load_user_agents(&source)?);
210 domain_cache.insert(domain, pool);
211 }
212
213 let session_cache = Cache::builder()
214 .time_to_live(self.session_duration.unwrap_or(Duration::from_secs(5 * 60)))
215 .build();
216
217 let middleware = UserAgentMiddleware {
218 strategy: self.strategy,
219 fallback_user_agent: self.fallback_user_agent,
220 domain_cache,
221 default_pool,
222 sticky_cache: DashMap::new(),
223 session_cache,
224 per_domain_strategy: self.per_domain_strategy,
225 current_index: AtomicUsize::new(0),
226 };
227
228 info!(
229 "Initializing UserAgentMiddleware with config: {:?}",
230 middleware
231 );
232
233 Ok(middleware)
234 }
235}
236
/// Middleware that sets the `User-Agent` header (and any profile-specific
/// extra headers) on outgoing requests according to a rotation strategy.
pub struct UserAgentMiddleware {
    /// Strategy used for domains without a per-domain override.
    strategy: UserAgentRotationStrategy,
    /// User-Agent used when the selected pool is empty.
    fallback_user_agent: Option<String>,
    /// Per-domain pools, keyed by domain.
    domain_cache: Cache<String, Arc<Vec<UserAgentProfile>>>,
    /// Pool used for domains without a per-domain pool.
    default_pool: Arc<Vec<UserAgentProfile>>,
    /// Profiles chosen under the `Sticky` strategy, keyed by domain.
    sticky_cache: DashMap<String, UserAgentProfile>,
    /// Profiles chosen under the `StickySession` strategy, keyed by domain;
    /// entries expire after the configured session duration.
    session_cache: Cache<String, UserAgentProfile>,
    /// Per-domain strategy overrides, keyed by domain.
    per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
    /// Running counter for the `Sequential` strategy.
    current_index: AtomicUsize,
}
247
248impl Debug for UserAgentMiddleware {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 f.debug_struct("UserAgentMiddleware")
251 .field("strategy", &self.strategy)
252 .field("fallback_user_agent", &self.fallback_user_agent)
253 .field(
254 "domain_cache",
255 &format!("Cache({})", self.domain_cache.weighted_size()),
256 )
257 .field(
258 "default_pool",
259 &format!("Pool({})", self.default_pool.len()),
260 )
261 .field(
262 "sticky_cache",
263 &format!("DashMap({})", self.sticky_cache.len()),
264 )
265 .field(
266 "session_cache",
267 &format!("Cache({})", self.session_cache.weighted_size()),
268 )
269 .field(
270 "per_domain_strategy",
271 &format!("DashMap({})", self.per_domain_strategy.len()),
272 )
273 .field("current_index", &self.current_index)
274 .finish()
275 }
276}
277
278impl UserAgentMiddleware {
279 pub fn builder() -> UserAgentMiddlewareBuilder {
281 UserAgentMiddlewareBuilder::default()
282 }
283
284 fn load_user_agents(source: &UserAgentSource) -> Result<Vec<UserAgentProfile>, SpiderError> {
285 match source {
286 UserAgentSource::List(list) => Ok(list
287 .iter()
288 .map(|ua| UserAgentProfile::from(ua.clone()))
289 .collect()),
290 UserAgentSource::File(path) => Self::load_from_file(path),
291 UserAgentSource::Builtin(builtin_list) => {
292 Ok(Self::load_builtin_user_agents(builtin_list))
293 }
294 UserAgentSource::None => Ok(Vec::new()),
295 }
296 }
297
298 fn load_from_file(path: &Path) -> Result<Vec<UserAgentProfile>, SpiderError> {
299 if !path.exists() {
300 return Err(SpiderError::IoError(
301 std::io::Error::new(
302 std::io::ErrorKind::NotFound,
303 format!("User-agent file not found: {}", path.display()),
304 )
305 .to_string(),
306 ));
307 }
308 let file = File::open(path)?;
309 let reader = BufReader::new(file);
310 let user_agents: Vec<UserAgentProfile> = reader
311 .lines()
312 .map_while(Result::ok)
313 .filter(|line| !line.trim().is_empty())
314 .map(UserAgentProfile::from)
315 .collect();
316
317 if user_agents.is_empty() {
318 warn!(
319 "User-Agent file {:?} is empty or contains no valid User-Agents.",
320 path
321 );
322 }
323 Ok(user_agents)
324 }
325
326 fn load_builtin_user_agents(list_type: &BuiltinUserAgentList) -> Vec<UserAgentProfile> {
327 let ua = match list_type {
328 BuiltinUserAgentList::Chrome => STATIC_CHROME_AGENTS,
329 BuiltinUserAgentList::ChromeLinux => STATIC_CHROME_LINUX_AGENTS,
330 BuiltinUserAgentList::ChromeMac => STATIC_CHROME_MAC_AGENTS,
331 BuiltinUserAgentList::ChromeMobile => STATIC_CHROME_MOBILE_AGENTS,
332 BuiltinUserAgentList::ChromeTablet => STATIC_CHROME_TABLET_AGENTS,
333 BuiltinUserAgentList::ChromeWindows => STATIC_CHROME_WINDOWS_AGENTS,
334 BuiltinUserAgentList::Firefox => STATIC_FIREFOX_AGENTS,
335 BuiltinUserAgentList::FirefoxLinux => STATIC_FIREFOX_LINUX_AGENTS,
336 BuiltinUserAgentList::FirefoxMac => STATIC_FIREFOX_MAC_AGENTS,
337 BuiltinUserAgentList::FirefoxMobile => STATIC_FIREFOX_MOBILE_AGENTS,
338 BuiltinUserAgentList::FirefoxTablet => STATIC_FIREFOX_TABLET_AGENTS,
339 BuiltinUserAgentList::FirefoxWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
340 BuiltinUserAgentList::Safari => STATIC_SAFARI_AGENTS,
341 BuiltinUserAgentList::SafariMac => STATIC_SAFARI_MAC_AGENTS,
342 BuiltinUserAgentList::SafariMobile => STATIC_SAFARI_MOBILE_AGENTS,
343 BuiltinUserAgentList::SafariTablet => STATIC_SAFARI_TABLET_AGENTS,
344 BuiltinUserAgentList::SafariWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
345 BuiltinUserAgentList::Random => all_static_agents(),
346 };
347
348 ua.iter().map(|&v| UserAgentProfile::from(v)).collect()
349 }
350
351 fn get_user_agent(&self, domain: Option<&str>) -> Option<UserAgentProfile> {
352 let mut rng = rand::thread_rng();
353
354 let domain_str = domain.unwrap_or_default().to_string();
355
356 let strategy = self
357 .per_domain_strategy
358 .get(&domain_str)
359 .map(|s| s.value().clone())
360 .unwrap_or_else(|| self.strategy.clone());
361
362 let pool = || {
363 domain
364 .and_then(|d| self.domain_cache.get(d))
365 .unwrap_or_else(|| self.default_pool.clone())
366 };
367
368 let get_fallback = || {
369 debug!("User-Agent pool is empty or no UA selected.");
370 self.fallback_user_agent
371 .as_ref()
372 .map(|ua| UserAgentProfile::from(ua.clone()))
373 };
374
375 match strategy {
376 UserAgentRotationStrategy::Random => {
377 let p = pool();
378 if p.is_empty() {
379 return get_fallback();
380 }
381 p.choose(&mut rng).cloned()
382 }
383 UserAgentRotationStrategy::Sequential => {
384 let p = pool();
385 if p.is_empty() {
386 return get_fallback();
387 }
388 let current = self.current_index.fetch_add(1, Ordering::SeqCst);
389 let index = current % p.len();
390 p.get(index).cloned()
391 }
392 UserAgentRotationStrategy::Sticky => {
393 if let Some(profile) = self.sticky_cache.get(&domain_str) {
394 return Some(profile.clone());
395 }
396
397 let p = pool();
398 if p.is_empty() {
399 return get_fallback();
400 }
401
402 if let Some(profile) = p.choose(&mut rng).cloned() {
403 self.sticky_cache.insert(domain_str, profile.clone());
404 Some(profile)
405 } else {
406 get_fallback()
407 }
408 }
409 UserAgentRotationStrategy::StickySession => {
410 if let Some(profile) = self.session_cache.get(&domain_str) {
411 return Some(profile);
412 }
413
414 let p = pool();
415 if p.is_empty() {
416 return get_fallback();
417 }
418
419 if let Some(profile) = p.choose(&mut rng).cloned() {
420 self.session_cache.insert(domain_str, profile.clone());
421 Some(profile)
422 } else {
423 get_fallback()
424 }
425 }
426 }
427 }
428}
429
430#[async_trait]
431impl<C: Send + Sync> Middleware<C> for UserAgentMiddleware {
432 fn name(&self) -> &str {
433 "UserAgentMiddleware"
434 }
435
436 async fn process_request(
437 &mut self,
438 _client: &C,
439 mut request: Request,
440 ) -> Result<MiddlewareAction<Request>, SpiderError> {
441 let domain = request.url.domain();
442 if let Some(profile) = self.get_user_agent(domain) {
443 debug!("Applying User-Agent: {}", profile.user_agent);
444 request.headers.insert(
445 USER_AGENT,
446 HeaderValue::from_str(&profile.user_agent).map_err(|e| {
447 SpiderError::HeaderValueError(format!(
448 "Invalid User-Agent string '{}': {}",
449 profile.user_agent, e
450 ))
451 })?,
452 );
453 for header in profile.headers.iter() {
454 request.headers.insert(
455 reqwest::header::HeaderName::from_bytes(header.key().as_bytes()).map_err(
456 |e| SpiderError::HeaderValueError(format!("Invalid header name: {}", e)),
457 )?,
458 HeaderValue::from_str(header.value().as_str()).map_err(|e| {
459 SpiderError::HeaderValueError(format!(
460 "Invalid header value for {}: {}",
461 header.key(),
462 e
463 ))
464 })?,
465 );
466 }
467 } else {
468 debug!("No User-Agent applied.");
469 }
470 Ok(MiddlewareAction::Continue(request))
471 }
472}