Skip to main content

spider_lib/middlewares/
user_agent.rs

1use async_trait::async_trait;
2use dashmap::DashMap;
3use moka::sync::Cache;
4use reqwest::header::{HeaderValue, USER_AGENT};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::fs::File;
8use std::io::{BufRead, BufReader};
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::time::Duration;
13use tracing::{debug, info, warn};
14
15use rand::seq::SliceRandom;
16
17use crate::error::SpiderError;
18use crate::middleware::{Middleware, MiddlewareAction};
19use crate::request::Request;
20
21/// Defines the strategy for rotating User-Agents.
22#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
23pub enum UserAgentRotationStrategy {
24    /// Randomly selects a User-Agent from the available pool.
25    #[default]
26    Random,
27    /// Sequentially cycles through the available User-Agents.
28    Sequential,
29}
30
31/// Predefined lists of User-Agents for common scenarios.
32#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
33pub enum BuiltinUserAgentList {
34    /// A list of popular desktop browser User-Agents.
35    Desktop,
36    /// A list of popular mobile browser User-Agents.
37    Mobile,
38    /// A mixed list of popular desktop and mobile browser User-Agents.
39    RandomPopular,
40}
41
42/// Defines the source from which User-Agents are loaded.
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44#[serde(untagged)] // Allow flexible deserialization
45pub enum UserAgentSource {
46    /// A direct list of User-Agent strings.
47    List(Vec<String>),
48    /// Path to a file containing User-Agent strings, one per line.
49    File(PathBuf),
50    /// Use a predefined, built-in list of User-Agents.
51    Builtin(BuiltinUserAgentList),
52    /// No User-Agent source specified, will fallback to a default if available.
53    None,
54}
55
56impl Default for UserAgentSource {
57    fn default() -> Self {
58        // Default to a builtin popular list for convenience
59        UserAgentSource::Builtin(BuiltinUserAgentList::RandomPopular)
60    }
61}
62
63/// Represents a User-Agent profile, including the User-Agent string and other associated headers.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct UserAgentProfile {
66    /// The User-Agent string.
67    pub user_agent: Arc<String>,
68    /// Additional headers that should be sent with this User-Agent to mimic a real browser.
69    #[serde(default)]
70    pub headers: DashMap<String, String>,
71}
72
73impl From<String> for UserAgentProfile {
74    fn from(user_agent: String) -> Self {
75        UserAgentProfile {
76            user_agent: Arc::new(user_agent),
77            headers: DashMap::new(),
78        }
79    }
80}
81
82impl From<&str> for UserAgentProfile {
83    fn from(user_agent: &str) -> Self {
84        UserAgentProfile {
85            user_agent: Arc::new(user_agent.to_string()),
86            headers: DashMap::new(),
87        }
88    }
89}
90
91/// Builder for creating a `UserAgentMiddleware`.
92#[derive(Debug, Clone, Default, Serialize, Deserialize)]
93pub struct UserAgentMiddlewareBuilder {
94    source: UserAgentSource,
95    strategy: UserAgentRotationStrategy,
96    fallback_user_agent: Option<String>,
97    per_domain: DashMap<String, UserAgentSource>,
98}
99
100impl UserAgentMiddlewareBuilder {
101    /// Sets the primary source for User-Agents.
102    pub fn source(mut self, source: UserAgentSource) -> Self {
103        self.source = source;
104        self
105    }
106
107    /// Sets the strategy to use for rotating User-Agents.
108    pub fn strategy(mut self, strategy: UserAgentRotationStrategy) -> Self {
109        self.strategy = strategy;
110        self
111    }
112
113    /// Sets a fallback User-Agent to use if no other User-Agents are available.
114    pub fn fallback_user_agent(mut self, fallback_user_agent: String) -> Self {
115        self.fallback_user_agent = Some(fallback_user_agent);
116        self
117    }
118
119    /// Adds a domain-specific User-Agent source.
120    pub fn per_domain(self, domain: String, source: UserAgentSource) -> Self {
121        self.per_domain.insert(domain, source);
122        self
123    }
124
125    /// Builds the `UserAgentMiddleware`.
126    /// This can fail if a User-Agent source file is specified but cannot be read.
127    pub fn build(self) -> Result<UserAgentMiddleware, SpiderError> {
128        let default_pool = Arc::new(UserAgentMiddleware::load_user_agents(&self.source)?);
129
130        let domain_cache = Cache::builder()
131            .time_to_live(Duration::from_secs(30 * 60)) // 30 minutes
132            .build();
133
134        for entry in self.per_domain.iter() {
135            let domain = entry.key().clone();
136            let source = entry.value().clone();
137            let pool = Arc::new(UserAgentMiddleware::load_user_agents(&source)?);
138            domain_cache.insert(domain, pool);
139        }
140
141        let middleware = UserAgentMiddleware {
142            strategy: self.strategy,
143            fallback_user_agent: self.fallback_user_agent,
144            domain_cache,
145            default_pool,
146            current_index: AtomicUsize::new(0),
147        };
148
149        info!(
150            "Initializing UserAgentMiddleware with config: {:?}",
151            middleware
152        );
153
154        Ok(middleware)
155    }
156}
157
158pub struct UserAgentMiddleware {
159    strategy: UserAgentRotationStrategy,
160    fallback_user_agent: Option<String>,
161    domain_cache: Cache<String, Arc<Vec<UserAgentProfile>>>,
162    default_pool: Arc<Vec<UserAgentProfile>>,
163    current_index: AtomicUsize,
164}
165
166impl Debug for UserAgentMiddleware {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("UserAgentMiddleware")
169            .field("strategy", &self.strategy)
170            .field("fallback_user_agent", &self.fallback_user_agent)
171            .field(
172                "domain_cache",
173                &format!("Cache({})", self.domain_cache.weighted_size()),
174            )
175            .field(
176                "default_pool",
177                &format!("Pool({})", self.default_pool.len()),
178            )
179            .field("current_index", &self.current_index)
180            .finish()
181    }
182}
183
184impl UserAgentMiddleware {
185    /// Creates a new `UserAgentMiddlewareBuilder` to start building a `UserAgentMiddleware`.
186    pub fn builder() -> UserAgentMiddlewareBuilder {
187        UserAgentMiddlewareBuilder::default()
188    }
189
190    fn load_user_agents(source: &UserAgentSource) -> Result<Vec<UserAgentProfile>, SpiderError> {
191        match source {
192            UserAgentSource::List(list) => Ok(list
193                .iter()
194                .map(|ua| UserAgentProfile::from(ua.clone()))
195                .collect()),
196            UserAgentSource::File(path) => Self::load_from_file(path),
197            UserAgentSource::Builtin(builtin_list) => {
198                Ok(Self::load_builtin_user_agents(builtin_list))
199            }
200            UserAgentSource::None => Ok(Vec::new()),
201        }
202    }
203
204    fn load_from_file(path: &Path) -> Result<Vec<UserAgentProfile>, SpiderError> {
205        if !path.exists() {
206            return Err(SpiderError::IoError(std::io::Error::new(
207                std::io::ErrorKind::NotFound,
208                format!("User-agent file not found: {}", path.display()),
209            )));
210        }
211        let file = File::open(path)?;
212        let reader = BufReader::new(file);
213        let user_agents: Vec<UserAgentProfile> = reader
214            .lines()
215            .map_while(Result::ok)
216            .filter(|line| !line.trim().is_empty())
217            .map(UserAgentProfile::from)
218            .collect();
219
220        if user_agents.is_empty() {
221            warn!(
222                "User-Agent file {:?} is empty or contains no valid User-Agents.",
223                path
224            );
225        }
226        Ok(user_agents)
227    }
228
229    //TODO: provide a list of user agents from third parties
230    fn load_builtin_user_agents(list_type: &BuiltinUserAgentList) -> Vec<UserAgentProfile> {
231        match list_type {
232            BuiltinUserAgentList::Desktop => vec![
233                UserAgentProfile::from(
234                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
235                ),
236                UserAgentProfile::from(
237                    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
238                ),
239                UserAgentProfile::from(
240                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/121.0",
241                ),
242            ],
243            BuiltinUserAgentList::Mobile => vec![
244                UserAgentProfile::from(
245                    "Mozilla/5.0 (Linux; Android 10) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.6099.193 Mobile Safari/537.36",
246                ),
247                UserAgentProfile::from(
248                    "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
249                ),
250            ],
251            BuiltinUserAgentList::RandomPopular => vec![
252                UserAgentProfile::from(
253                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
254                ),
255                UserAgentProfile::from(
256                    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
257                ),
258                UserAgentProfile::from(
259                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/121.0",
260                ),
261                UserAgentProfile::from(
262                    "Mozilla/5.0 (Linux; Android 10) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.6099.193 Mobile Safari/537.36",
263                ),
264                UserAgentProfile::from(
265                    "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
266                ),
267            ],
268        }
269    }
270
271    fn get_user_agent(&self, domain: Option<&str>) -> Option<UserAgentProfile> {
272        let mut rng = rand::thread_rng();
273
274        let pool = domain
275            .and_then(|d| self.domain_cache.get(d))
276            .unwrap_or_else(|| self.default_pool.clone());
277
278        if pool.is_empty() {
279            debug!("User-Agent pool is empty.");
280            return self
281                .fallback_user_agent
282                .as_ref()
283                .map(|ua| UserAgentProfile::from(ua.clone()));
284        }
285
286        match self.strategy {
287            UserAgentRotationStrategy::Random => pool.choose(&mut rng).cloned(),
288            UserAgentRotationStrategy::Sequential => {
289                let current = self.current_index.fetch_add(1, Ordering::SeqCst);
290                let index = current % pool.len();
291                pool.get(index).cloned()
292            }
293        }
294    }
295}
296
297#[async_trait]
298impl<C: Send + Sync> Middleware<C> for UserAgentMiddleware {
299    fn name(&self) -> &str {
300        "UserAgentMiddleware"
301    }
302
303    async fn process_request(
304        &mut self,
305        _client: &C,
306        mut request: Request,
307    ) -> Result<MiddlewareAction<Request>, SpiderError> {
308        let domain = request.url.domain();
309        if let Some(profile) = self.get_user_agent(domain) {
310            debug!("Applying User-Agent: {}", profile.user_agent);
311            request.headers.insert(
312                USER_AGENT,
313                HeaderValue::from_str(&profile.user_agent).map_err(|e| {
314                    SpiderError::HeaderValueError(format!(
315                        "Invalid User-Agent string '{}': {}",
316                        profile.user_agent, e
317                    ))
318                })?,
319            );
320            for header in profile.headers.iter() {
321                request.headers.insert(
322                    reqwest::header::HeaderName::from_bytes(header.key().as_bytes()).map_err(
323                        |e| SpiderError::HeaderValueError(format!("Invalid header name: {}", e)),
324                    )?,
325                    HeaderValue::from_str(header.value().as_str()).map_err(|e| {
326                        SpiderError::HeaderValueError(format!(
327                            "Invalid header value for {}: {}",
328                            header.key(),
329                            e
330                        ))
331                    })?,
332                );
333            }
334        } else {
335            debug!("No User-Agent applied.");
336        }
337        Ok(MiddlewareAction::Continue(request))
338    }
339}