1use async_trait::async_trait;
8use dashmap::DashMap;
9use moka::sync::Cache;
10use reqwest::header::{HeaderValue, USER_AGENT};
11use serde::{Deserialize, Serialize};
12use std::fmt::Debug;
13use std::fs::File;
14use std::io::{BufRead, BufReader};
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20use ua_generator::ua::*;
21
22use rand::seq::SliceRandom;
23
24use crate::error::SpiderError;
25use crate::middleware::{Middleware, MiddlewareAction};
26use crate::request::Request;
27
28#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
30pub enum UserAgentRotationStrategy {
31 #[default]
33 Random,
34 Sequential,
36 Sticky,
38 StickySession,
40}
41
42#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44pub enum BuiltinUserAgentList {
45 Chrome,
47 ChromeLinux,
49 ChromeMac,
51 ChromeMobile,
53 ChromeTablet,
55 ChromeWindows,
57 Firefox,
59 FirefoxLinux,
61 FirefoxMac,
63 FirefoxMobile,
65 FirefoxTablet,
67 FirefoxWindows,
69 Safari,
71 SafariMac,
73 SafariMobile,
75 SafariTablet,
77 SafariWindows,
79 Random,
81}
82
83#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
85#[serde(untagged)]
86pub enum UserAgentSource {
87 List(Vec<String>),
89 File(PathBuf),
91 Builtin(BuiltinUserAgentList),
93 None,
95}
96
97impl Default for UserAgentSource {
98 fn default() -> Self {
99 UserAgentSource::Builtin(BuiltinUserAgentList::Random)
100 }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct UserAgentProfile {
106 pub user_agent: Arc<String>,
108 #[serde(default)]
110 pub headers: DashMap<String, String>,
111}
112
113impl From<String> for UserAgentProfile {
114 fn from(user_agent: String) -> Self {
115 UserAgentProfile {
116 user_agent: Arc::new(user_agent),
117 headers: DashMap::new(),
118 }
119 }
120}
121
122impl From<&str> for UserAgentProfile {
123 fn from(user_agent: &str) -> Self {
124 UserAgentProfile {
125 user_agent: Arc::new(user_agent.to_string()),
126 headers: DashMap::new(),
127 }
128 }
129}
130
131#[derive(Debug, Clone, Default, Serialize, Deserialize)]
133pub struct UserAgentMiddlewareBuilder {
134 source: UserAgentSource,
135 strategy: UserAgentRotationStrategy,
136 fallback_user_agent: Option<String>,
137 per_domain_source: DashMap<String, UserAgentSource>,
138 per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
139 session_duration: Option<Duration>,
140}
141
142impl UserAgentMiddlewareBuilder {
143 pub fn source(mut self, source: UserAgentSource) -> Self {
145 self.source = source;
146 self
147 }
148
149 pub fn strategy(mut self, strategy: UserAgentRotationStrategy) -> Self {
151 self.strategy = strategy;
152 self
153 }
154
155 pub fn session_duration(mut self, duration: Duration) -> Self {
157 self.session_duration = Some(duration);
158 self
159 }
160
161 pub fn fallback_user_agent(mut self, fallback_user_agent: String) -> Self {
163 self.fallback_user_agent = Some(fallback_user_agent);
164 self
165 }
166
167 pub fn per_domain_source(self, domain: String, source: UserAgentSource) -> Self {
169 self.per_domain_source.insert(domain, source);
170 self
171 }
172
173 pub fn per_domain_strategy(self, domain: String, strategy: UserAgentRotationStrategy) -> Self {
175 self.per_domain_strategy.insert(domain, strategy);
176 self
177 }
178
179 pub fn build(self) -> Result<UserAgentMiddleware, SpiderError> {
182 let default_pool = Arc::new(UserAgentMiddleware::load_user_agents(&self.source)?);
183
184 let domain_cache = Cache::builder()
185 .time_to_live(Duration::from_secs(30 * 60)) .build();
187
188 for entry in self.per_domain_source.iter() {
189 let domain = entry.key().clone();
190 let source = entry.value().clone();
191 let pool = Arc::new(UserAgentMiddleware::load_user_agents(&source)?);
192 domain_cache.insert(domain, pool);
193 }
194
195 let session_cache = Cache::builder()
196 .time_to_live(self.session_duration.unwrap_or(Duration::from_secs(5 * 60)))
197 .build();
198
199 let middleware = UserAgentMiddleware {
200 strategy: self.strategy,
201 fallback_user_agent: self.fallback_user_agent,
202 domain_cache,
203 default_pool,
204 sticky_cache: DashMap::new(),
205 session_cache,
206 per_domain_strategy: self.per_domain_strategy,
207 current_index: AtomicUsize::new(0),
208 };
209
210 info!(
211 "Initializing UserAgentMiddleware with config: {:?}",
212 middleware
213 );
214
215 Ok(middleware)
216 }
217}
218
219pub struct UserAgentMiddleware {
220 strategy: UserAgentRotationStrategy,
221 fallback_user_agent: Option<String>,
222 domain_cache: Cache<String, Arc<Vec<UserAgentProfile>>>,
223 default_pool: Arc<Vec<UserAgentProfile>>,
224 sticky_cache: DashMap<String, UserAgentProfile>,
225 session_cache: Cache<String, UserAgentProfile>,
226 per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
227 current_index: AtomicUsize,
228}
229
230impl Debug for UserAgentMiddleware {
231 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232 f.debug_struct("UserAgentMiddleware")
233 .field("strategy", &self.strategy)
234 .field("fallback_user_agent", &self.fallback_user_agent)
235 .field(
236 "domain_cache",
237 &format!("Cache({})", self.domain_cache.weighted_size()),
238 )
239 .field(
240 "default_pool",
241 &format!("Pool({})", self.default_pool.len()),
242 )
243 .field(
244 "sticky_cache",
245 &format!("DashMap({})", self.sticky_cache.len()),
246 )
247 .field(
248 "session_cache",
249 &format!("Cache({})", self.session_cache.weighted_size()),
250 )
251 .field(
252 "per_domain_strategy",
253 &format!("DashMap({})", self.per_domain_strategy.len()),
254 )
255 .field("current_index", &self.current_index)
256 .finish()
257 }
258}
259
260impl UserAgentMiddleware {
261 pub fn builder() -> UserAgentMiddlewareBuilder {
263 UserAgentMiddlewareBuilder::default()
264 }
265
266 fn load_user_agents(source: &UserAgentSource) -> Result<Vec<UserAgentProfile>, SpiderError> {
267 match source {
268 UserAgentSource::List(list) => Ok(list
269 .iter()
270 .map(|ua| UserAgentProfile::from(ua.clone()))
271 .collect()),
272 UserAgentSource::File(path) => Self::load_from_file(path),
273 UserAgentSource::Builtin(builtin_list) => {
274 Ok(Self::load_builtin_user_agents(builtin_list))
275 }
276 UserAgentSource::None => Ok(Vec::new()),
277 }
278 }
279
280 fn load_from_file(path: &Path) -> Result<Vec<UserAgentProfile>, SpiderError> {
281 if !path.exists() {
282 return Err(SpiderError::IoError(
283 std::io::Error::new(
284 std::io::ErrorKind::NotFound,
285 format!("User-agent file not found: {}", path.display()),
286 )
287 .to_string(),
288 ));
289 }
290 let file = File::open(path)?;
291 let reader = BufReader::new(file);
292 let user_agents: Vec<UserAgentProfile> = reader
293 .lines()
294 .map_while(Result::ok)
295 .filter(|line| !line.trim().is_empty())
296 .map(UserAgentProfile::from)
297 .collect();
298
299 if user_agents.is_empty() {
300 warn!(
301 "User-Agent file {:?} is empty or contains no valid User-Agents.",
302 path
303 );
304 }
305 Ok(user_agents)
306 }
307
308 fn load_builtin_user_agents(list_type: &BuiltinUserAgentList) -> Vec<UserAgentProfile> {
309 let ua = match list_type {
310 BuiltinUserAgentList::Chrome => STATIC_CHROME_AGENTS,
311 BuiltinUserAgentList::ChromeLinux => STATIC_CHROME_LINUX_AGENTS,
312 BuiltinUserAgentList::ChromeMac => STATIC_CHROME_MAC_AGENTS,
313 BuiltinUserAgentList::ChromeMobile => STATIC_CHROME_MOBILE_AGENTS,
314 BuiltinUserAgentList::ChromeTablet => STATIC_CHROME_TABLET_AGENTS,
315 BuiltinUserAgentList::ChromeWindows => STATIC_CHROME_WINDOWS_AGENTS,
316 BuiltinUserAgentList::Firefox => STATIC_FIREFOX_AGENTS,
317 BuiltinUserAgentList::FirefoxLinux => STATIC_FIREFOX_LINUX_AGENTS,
318 BuiltinUserAgentList::FirefoxMac => STATIC_FIREFOX_MAC_AGENTS,
319 BuiltinUserAgentList::FirefoxMobile => STATIC_FIREFOX_MOBILE_AGENTS,
320 BuiltinUserAgentList::FirefoxTablet => STATIC_FIREFOX_TABLET_AGENTS,
321 BuiltinUserAgentList::FirefoxWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
322 BuiltinUserAgentList::Safari => STATIC_SAFARI_AGENTS,
323 BuiltinUserAgentList::SafariMac => STATIC_SAFARI_MAC_AGENTS,
324 BuiltinUserAgentList::SafariMobile => STATIC_SAFARI_MOBILE_AGENTS,
325 BuiltinUserAgentList::SafariTablet => STATIC_SAFARI_TABLET_AGENTS,
326 BuiltinUserAgentList::SafariWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
327 BuiltinUserAgentList::Random => all_static_agents(),
328 };
329
330 ua.iter().map(|&v| UserAgentProfile::from(v)).collect()
331 }
332
333 fn get_user_agent(&self, domain: Option<&str>) -> Option<UserAgentProfile> {
334 let mut rng = rand::thread_rng();
335
336 let domain_str = domain.unwrap_or_default().to_string();
337
338 let strategy = self
339 .per_domain_strategy
340 .get(&domain_str)
341 .map(|s| s.value().clone())
342 .unwrap_or_else(|| self.strategy.clone());
343
344 let pool = || {
345 domain
346 .and_then(|d| self.domain_cache.get(d))
347 .unwrap_or_else(|| self.default_pool.clone())
348 };
349
350 let get_fallback = || {
351 debug!("User-Agent pool is empty or no UA selected.");
352 self.fallback_user_agent
353 .as_ref()
354 .map(|ua| UserAgentProfile::from(ua.clone()))
355 };
356
357 match strategy {
358 UserAgentRotationStrategy::Random => {
359 let p = pool();
360 if p.is_empty() {
361 return get_fallback();
362 }
363 p.choose(&mut rng).cloned()
364 }
365 UserAgentRotationStrategy::Sequential => {
366 let p = pool();
367 if p.is_empty() {
368 return get_fallback();
369 }
370 let current = self.current_index.fetch_add(1, Ordering::SeqCst);
371 let index = current % p.len();
372 p.get(index).cloned()
373 }
374 UserAgentRotationStrategy::Sticky => {
375 if let Some(profile) = self.sticky_cache.get(&domain_str) {
376 return Some(profile.clone());
377 }
378
379 let p = pool();
380 if p.is_empty() {
381 return get_fallback();
382 }
383
384 if let Some(profile) = p.choose(&mut rng).cloned() {
385 self.sticky_cache.insert(domain_str, profile.clone());
386 Some(profile)
387 } else {
388 get_fallback()
389 }
390 }
391 UserAgentRotationStrategy::StickySession => {
392 if let Some(profile) = self.session_cache.get(&domain_str) {
393 return Some(profile);
394 }
395
396 let p = pool();
397 if p.is_empty() {
398 return get_fallback();
399 }
400
401 if let Some(profile) = p.choose(&mut rng).cloned() {
402 self.session_cache.insert(domain_str, profile.clone());
403 Some(profile)
404 } else {
405 get_fallback()
406 }
407 }
408 }
409 }
410}
411
412#[async_trait]
413impl<C: Send + Sync> Middleware<C> for UserAgentMiddleware {
414 fn name(&self) -> &str {
415 "UserAgentMiddleware"
416 }
417
418 async fn process_request(
419 &mut self,
420 _client: &C,
421 mut request: Request,
422 ) -> Result<MiddlewareAction<Request>, SpiderError> {
423 let domain = request.url.domain();
424 if let Some(profile) = self.get_user_agent(domain) {
425 debug!("Applying User-Agent: {}", profile.user_agent);
426 request.headers.insert(
427 USER_AGENT,
428 HeaderValue::from_str(&profile.user_agent).map_err(|e| {
429 SpiderError::HeaderValueError(format!(
430 "Invalid User-Agent string '{}': {}",
431 profile.user_agent, e
432 ))
433 })?,
434 );
435 for header in profile.headers.iter() {
436 request.headers.insert(
437 reqwest::header::HeaderName::from_bytes(header.key().as_bytes()).map_err(
438 |e| SpiderError::HeaderValueError(format!("Invalid header name: {}", e)),
439 )?,
440 HeaderValue::from_str(header.value().as_str()).map_err(|e| {
441 SpiderError::HeaderValueError(format!(
442 "Invalid header value for {}: {}",
443 header.key(),
444 e
445 ))
446 })?,
447 );
448 }
449 } else {
450 debug!("No User-Agent applied.");
451 }
452 Ok(MiddlewareAction::Continue(request))
453 }
454}