Skip to main content

spider_middleware/
user_agent.rs

1//! User-Agent Middleware for rotating User-Agents during crawling.
2//!
3//! This module provides `UserAgentMiddleware` for managing and rotating User-Agent strings for
4//! outgoing requests. It supports various rotation strategies and allows for detailed
5//! configuration on a per-domain basis.
6
7use async_trait::async_trait;
8use dashmap::DashMap;
9use moka::sync::Cache;
10use reqwest::header::{HeaderValue, USER_AGENT};
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::fmt::Debug;
13use std::fs::File;
14use std::io::{BufRead, BufReader};
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20use ua_generator::ua::*;
21
22use rand::seq::SliceRandom;
23
24use spider_util::error::SpiderError;
25use crate::middleware::{Middleware, MiddlewareAction};
26use spider_util::request::Request;
27
28/// Defines the strategy for rotating User-Agents.
29#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
30pub enum UserAgentRotationStrategy {
31    /// Randomly selects a User-Agent from the available pool.
32    #[default]
33    Random,
34    /// Sequentially cycles through the available User-Agents.
35    Sequential,
36    /// Selects a User-Agent on first encounter with a domain and uses it for all subsequent requests to that domain.
37    Sticky,
38    /// Selects a User-Agent on first encounter with a domain and uses it for a configured duration (session).
39    StickySession,
40}
41
42/// Predefined lists of User-Agents for common scenarios.
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44pub enum BuiltinUserAgentList {
45    /// Generic Chrome User-Agents.
46    Chrome,
47    /// Chrome User-Agents on Linux.
48    ChromeLinux,
49    /// Chrome User-Agents on Mac.
50    ChromeMac,
51    /// Chrome Mobile User-Agents.
52    ChromeMobile,
53    /// Chrome Tablet User-Agents.
54    ChromeTablet,
55    /// Chrome User-Agents on Windows.
56    ChromeWindows,
57    /// Generic Firefox User-Agents.
58    Firefox,
59    /// Firefox User-Agents on Linux.
60    FirefoxLinux,
61    /// Firefox User-Agents on Mac.
62    FirefoxMac,
63    /// Firefox Mobile User-Agents.
64    FirefoxMobile,
65    /// Firefox Tablet User-Agents.
66    FirefoxTablet,
67    /// Firefox User-Agents on Windows.
68    FirefoxWindows,
69    /// Generic Safari User-Agents.
70    Safari,
71    /// Safari User-Agents on Mac.
72    SafariMac,
73    /// Safari Mobile User-Agents.
74    SafariMobile,
75    /// Safari Tablet User-Agents.
76    SafariTablet,
77    /// Safari User-Agents on Windows.
78    SafariWindows,
79    /// A random selection from all available User-Agents.
80    Random,
81}
82
83/// Defines the source from which User-Agents are loaded.
84#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
85#[serde(untagged)]
86pub enum UserAgentSource {
87    /// A direct list of User-Agent strings.
88    List(Vec<String>),
89    /// Path to a file containing User-Agent strings, one per line.
90    File(PathBuf),
91    /// Use a predefined, built-in list of User-Agents.
92    Builtin(BuiltinUserAgentList),
93    /// No User-Agent source specified, will fallback to a default if available.
94    None,
95}
96
97impl Default for UserAgentSource {
98    fn default() -> Self {
99        UserAgentSource::Builtin(BuiltinUserAgentList::Random)
100    }
101}
102
103/// Custom serializer for Arc<String>
104fn serialize_arc_string<S>(x: &Arc<String>, s: S) -> Result<S::Ok, S::Error>
105where
106    S: Serializer,
107{
108    s.serialize_str(x.as_str())
109}
110
111/// Custom deserializer for Arc<String>
112fn deserialize_arc_string<'de, D>(deserializer: D) -> Result<Arc<String>, D::Error>
113where
114    D: Deserializer<'de>,
115{
116    let s = String::deserialize(deserializer)?;
117    Ok(Arc::new(s))
118}
119
120/// Represents a User-Agent profile, including the User-Agent string and other associated headers.
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct UserAgentProfile {
123    /// The User-Agent string.
124    #[serde(serialize_with = "serialize_arc_string", deserialize_with = "deserialize_arc_string")]
125    pub user_agent: Arc<String>,
126    /// Additional headers that should be sent with this User-Agent to mimic a real browser.
127    #[serde(default)]
128    pub headers: DashMap<String, String>,
129}
130
131impl From<String> for UserAgentProfile {
132    fn from(user_agent: String) -> Self {
133        UserAgentProfile {
134            user_agent: Arc::new(user_agent),
135            headers: DashMap::new(),
136        }
137    }
138}
139
140impl From<&str> for UserAgentProfile {
141    fn from(user_agent: &str) -> Self {
142        UserAgentProfile {
143            user_agent: Arc::new(user_agent.to_string()),
144            headers: DashMap::new(),
145        }
146    }
147}
148
149/// Builder for creating a `UserAgentMiddleware`.
150#[derive(Debug, Clone, Default, Serialize, Deserialize)]
151pub struct UserAgentMiddlewareBuilder {
152    source: UserAgentSource,
153    strategy: UserAgentRotationStrategy,
154    fallback_user_agent: Option<String>,
155    per_domain_source: DashMap<String, UserAgentSource>,
156    per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
157    session_duration: Option<Duration>,
158}
159
160impl UserAgentMiddlewareBuilder {
161    /// Sets the primary source for User-Agents.
162    pub fn source(mut self, source: UserAgentSource) -> Self {
163        self.source = source;
164        self
165    }
166
167    /// Sets the default strategy to use for rotating User-Agents.
168    pub fn strategy(mut self, strategy: UserAgentRotationStrategy) -> Self {
169        self.strategy = strategy;
170        self
171    }
172
173    /// Sets the duration for a "sticky session" in the `StickySession` strategy.
174    pub fn session_duration(mut self, duration: Duration) -> Self {
175        self.session_duration = Some(duration);
176        self
177    }
178
179    /// Sets a fallback User-Agent to use if no other User-Agents are available.
180    pub fn fallback_user_agent(mut self, fallback_user_agent: String) -> Self {
181        self.fallback_user_agent = Some(fallback_user_agent);
182        self
183    }
184
185    /// Adds a domain-specific User-Agent source.
186    pub fn per_domain_source(self, domain: String, source: UserAgentSource) -> Self {
187        self.per_domain_source.insert(domain, source);
188        self
189    }
190
191    /// Adds a domain-specific User-Agent rotation strategy, overriding the default.
192    pub fn per_domain_strategy(self, domain: String, strategy: UserAgentRotationStrategy) -> Self {
193        self.per_domain_strategy.insert(domain, strategy);
194        self
195    }
196
197    /// Builds the `UserAgentMiddleware`.
198    /// This can fail if a User-Agent source file is specified but cannot be read.
199    pub fn build(self) -> Result<UserAgentMiddleware, SpiderError> {
200        let default_pool = Arc::new(UserAgentMiddleware::load_user_agents(&self.source)?);
201
202        let domain_cache = Cache::builder()
203            .time_to_live(Duration::from_secs(30 * 60)) // 30 minutes
204            .build();
205
206        for entry in self.per_domain_source.iter() {
207            let domain = entry.key().clone();
208            let source = entry.value().clone();
209            let pool = Arc::new(UserAgentMiddleware::load_user_agents(&source)?);
210            domain_cache.insert(domain, pool);
211        }
212
213        let session_cache = Cache::builder()
214            .time_to_live(self.session_duration.unwrap_or(Duration::from_secs(5 * 60)))
215            .build();
216
217        let middleware = UserAgentMiddleware {
218            strategy: self.strategy,
219            fallback_user_agent: self.fallback_user_agent,
220            domain_cache,
221            default_pool,
222            sticky_cache: DashMap::new(),
223            session_cache,
224            per_domain_strategy: self.per_domain_strategy,
225            current_index: AtomicUsize::new(0),
226        };
227
228        info!(
229            "Initializing UserAgentMiddleware with config: {:?}",
230            middleware
231        );
232
233        Ok(middleware)
234    }
235}
236
237pub struct UserAgentMiddleware {
238    strategy: UserAgentRotationStrategy,
239    fallback_user_agent: Option<String>,
240    domain_cache: Cache<String, Arc<Vec<UserAgentProfile>>>,
241    default_pool: Arc<Vec<UserAgentProfile>>,
242    sticky_cache: DashMap<String, UserAgentProfile>,
243    session_cache: Cache<String, UserAgentProfile>,
244    per_domain_strategy: DashMap<String, UserAgentRotationStrategy>,
245    current_index: AtomicUsize,
246}
247
248impl Debug for UserAgentMiddleware {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        f.debug_struct("UserAgentMiddleware")
251            .field("strategy", &self.strategy)
252            .field("fallback_user_agent", &self.fallback_user_agent)
253            .field(
254                "domain_cache",
255                &format!("Cache({})", self.domain_cache.weighted_size()),
256            )
257            .field(
258                "default_pool",
259                &format!("Pool({})", self.default_pool.len()),
260            )
261            .field(
262                "sticky_cache",
263                &format!("DashMap({})", self.sticky_cache.len()),
264            )
265            .field(
266                "session_cache",
267                &format!("Cache({})", self.session_cache.weighted_size()),
268            )
269            .field(
270                "per_domain_strategy",
271                &format!("DashMap({})", self.per_domain_strategy.len()),
272            )
273            .field("current_index", &self.current_index)
274            .finish()
275    }
276}
277
278impl UserAgentMiddleware {
279    /// Creates a new `UserAgentMiddlewareBuilder` to start building a `UserAgentMiddleware`.
280    pub fn builder() -> UserAgentMiddlewareBuilder {
281        UserAgentMiddlewareBuilder::default()
282    }
283
284    fn load_user_agents(source: &UserAgentSource) -> Result<Vec<UserAgentProfile>, SpiderError> {
285        match source {
286            UserAgentSource::List(list) => Ok(list
287                .iter()
288                .map(|ua| UserAgentProfile::from(ua.clone()))
289                .collect()),
290            UserAgentSource::File(path) => Self::load_from_file(path),
291            UserAgentSource::Builtin(builtin_list) => {
292                Ok(Self::load_builtin_user_agents(builtin_list))
293            }
294            UserAgentSource::None => Ok(Vec::new()),
295        }
296    }
297
298    fn load_from_file(path: &Path) -> Result<Vec<UserAgentProfile>, SpiderError> {
299        if !path.exists() {
300            return Err(SpiderError::IoError(
301                std::io::Error::new(
302                    std::io::ErrorKind::NotFound,
303                    format!("User-agent file not found: {}", path.display()),
304                )
305                .to_string(),
306            ));
307        }
308        let file = File::open(path)?;
309        let reader = BufReader::new(file);
310        let user_agents: Vec<UserAgentProfile> = reader
311            .lines()
312            .map_while(Result::ok)
313            .filter(|line| !line.trim().is_empty())
314            .map(UserAgentProfile::from)
315            .collect();
316
317        if user_agents.is_empty() {
318            warn!(
319                "User-Agent file {:?} is empty or contains no valid User-Agents.",
320                path
321            );
322        }
323        Ok(user_agents)
324    }
325
326    fn load_builtin_user_agents(list_type: &BuiltinUserAgentList) -> Vec<UserAgentProfile> {
327        let ua = match list_type {
328            BuiltinUserAgentList::Chrome => STATIC_CHROME_AGENTS,
329            BuiltinUserAgentList::ChromeLinux => STATIC_CHROME_LINUX_AGENTS,
330            BuiltinUserAgentList::ChromeMac => STATIC_CHROME_MAC_AGENTS,
331            BuiltinUserAgentList::ChromeMobile => STATIC_CHROME_MOBILE_AGENTS,
332            BuiltinUserAgentList::ChromeTablet => STATIC_CHROME_TABLET_AGENTS,
333            BuiltinUserAgentList::ChromeWindows => STATIC_CHROME_WINDOWS_AGENTS,
334            BuiltinUserAgentList::Firefox => STATIC_FIREFOX_AGENTS,
335            BuiltinUserAgentList::FirefoxLinux => STATIC_FIREFOX_LINUX_AGENTS,
336            BuiltinUserAgentList::FirefoxMac => STATIC_FIREFOX_MAC_AGENTS,
337            BuiltinUserAgentList::FirefoxMobile => STATIC_FIREFOX_MOBILE_AGENTS,
338            BuiltinUserAgentList::FirefoxTablet => STATIC_FIREFOX_TABLET_AGENTS,
339            BuiltinUserAgentList::FirefoxWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
340            BuiltinUserAgentList::Safari => STATIC_SAFARI_AGENTS,
341            BuiltinUserAgentList::SafariMac => STATIC_SAFARI_MAC_AGENTS,
342            BuiltinUserAgentList::SafariMobile => STATIC_SAFARI_MOBILE_AGENTS,
343            BuiltinUserAgentList::SafariTablet => STATIC_SAFARI_TABLET_AGENTS,
344            BuiltinUserAgentList::SafariWindows => STATIC_FIREFOX_WINDOWS_AGENTS,
345            BuiltinUserAgentList::Random => all_static_agents(),
346        };
347
348        ua.iter().map(|&v| UserAgentProfile::from(v)).collect()
349    }
350
351    fn get_user_agent(&self, domain: Option<&str>) -> Option<UserAgentProfile> {
352        let mut rng = rand::thread_rng();
353
354        let domain_str = domain.unwrap_or_default().to_string();
355
356        let strategy = self
357            .per_domain_strategy
358            .get(&domain_str)
359            .map(|s| s.value().clone())
360            .unwrap_or_else(|| self.strategy.clone());
361
362        let pool = || {
363            domain
364                .and_then(|d| self.domain_cache.get(d))
365                .unwrap_or_else(|| self.default_pool.clone())
366        };
367
368        let get_fallback = || {
369            debug!("User-Agent pool is empty or no UA selected.");
370            self.fallback_user_agent
371                .as_ref()
372                .map(|ua| UserAgentProfile::from(ua.clone()))
373        };
374
375        match strategy {
376            UserAgentRotationStrategy::Random => {
377                let p = pool();
378                if p.is_empty() {
379                    return get_fallback();
380                }
381                p.choose(&mut rng).cloned()
382            }
383            UserAgentRotationStrategy::Sequential => {
384                let p = pool();
385                if p.is_empty() {
386                    return get_fallback();
387                }
388                let current = self.current_index.fetch_add(1, Ordering::SeqCst);
389                let index = current % p.len();
390                p.get(index).cloned()
391            }
392            UserAgentRotationStrategy::Sticky => {
393                if let Some(profile) = self.sticky_cache.get(&domain_str) {
394                    return Some(profile.clone());
395                }
396
397                let p = pool();
398                if p.is_empty() {
399                    return get_fallback();
400                }
401
402                if let Some(profile) = p.choose(&mut rng).cloned() {
403                    self.sticky_cache.insert(domain_str, profile.clone());
404                    Some(profile)
405                } else {
406                    get_fallback()
407                }
408            }
409            UserAgentRotationStrategy::StickySession => {
410                if let Some(profile) = self.session_cache.get(&domain_str) {
411                    return Some(profile);
412                }
413
414                let p = pool();
415                if p.is_empty() {
416                    return get_fallback();
417                }
418
419                if let Some(profile) = p.choose(&mut rng).cloned() {
420                    self.session_cache.insert(domain_str, profile.clone());
421                    Some(profile)
422                } else {
423                    get_fallback()
424                }
425            }
426        }
427    }
428}
429
430#[async_trait]
431impl<C: Send + Sync> Middleware<C> for UserAgentMiddleware {
432    fn name(&self) -> &str {
433        "UserAgentMiddleware"
434    }
435
436    async fn process_request(
437        &mut self,
438        _client: &C,
439        mut request: Request,
440    ) -> Result<MiddlewareAction<Request>, SpiderError> {
441        let domain = request.url.domain();
442        if let Some(profile) = self.get_user_agent(domain) {
443            debug!("Applying User-Agent: {}", profile.user_agent);
444            request.headers.insert(
445                USER_AGENT,
446                HeaderValue::from_str(&profile.user_agent).map_err(|e| {
447                    SpiderError::HeaderValueError(format!(
448                        "Invalid User-Agent string '{}': {}",
449                        profile.user_agent, e
450                    ))
451                })?,
452            );
453            for header in profile.headers.iter() {
454                request.headers.insert(
455                    reqwest::header::HeaderName::from_bytes(header.key().as_bytes()).map_err(
456                        |e| SpiderError::HeaderValueError(format!("Invalid header name: {}", e)),
457                    )?,
458                    HeaderValue::from_str(header.value().as_str()).map_err(|e| {
459                        SpiderError::HeaderValueError(format!(
460                            "Invalid header value for {}: {}",
461                            header.key(),
462                            e
463                        ))
464                    })?,
465                );
466            }
467        } else {
468            debug!("No User-Agent applied.");
469        }
470        Ok(MiddlewareAction::Continue(request))
471    }
472}