spider_lib/middlewares/robots_txt.rs

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use http::header::USER_AGENT;
use moka::future::Cache;
use robotstxt::DefaultMatcher;
use tracing::{debug, info, warn};

use crate::downloader::SimpleHttpClient;
use crate::error::SpiderError;
use crate::middleware::{Middleware, MiddlewareAction};
use crate::request::Request;

/// Middleware that enforces `robots.txt` rules.
///
/// The robots.txt file of each request's origin is fetched once, cached per
/// origin, and consulted before the request is allowed through.
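///
/// # Example
///
/// A minimal configuration sketch (illustrative only; the surrounding crate
/// path and imports are assumed, so the doctest is marked `ignore`):
///
/// ```ignore
/// use std::time::Duration;
///
/// let robots = RobotsTxtMiddleware::new()
///     .cache_ttl(Duration::from_secs(60 * 60))    // keep robots.txt bodies for one hour
///     .cache_capacity(1_000)                      // cache at most 1,000 origins
///     .request_timeout(Duration::from_secs(10));  // give each robots.txt fetch 10 seconds
/// ```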
#[derive(Debug)]
pub struct RobotsTxtMiddleware {
    cache_ttl: Duration,
    cache_capacity: u64,
    request_timeout: Duration,
    cache: Cache<String, Arc<String>>,
}

impl Default for RobotsTxtMiddleware {
    fn default() -> Self {
        let cache_ttl = Duration::from_secs(60 * 60 * 24);
        let cache_capacity = 10_000;
        let cache = Cache::builder()
            .time_to_live(cache_ttl)
            .max_capacity(cache_capacity)
            .build();

        let middleware = Self {
            cache_ttl,
            cache_capacity,
            request_timeout: Duration::from_secs(5),
            cache,
        };
        info!(
            "Initializing RobotsTxtMiddleware with config: {:?}",
            middleware
        );
        middleware
    }
}

impl RobotsTxtMiddleware {
    /// Creates a new `RobotsTxtMiddleware` with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the time-to-live for cached robots.txt bodies.
    ///
    /// Rebuilds the cache, discarding any entries cached so far.
    pub fn cache_ttl(mut self, cache_ttl: Duration) -> Self {
        self.cache_ttl = cache_ttl;
        self.rebuild_cache();
        self
    }

    /// Sets the maximum number of origins whose robots.txt body is cached.
    ///
    /// Rebuilds the cache, discarding any entries cached so far.
    pub fn cache_capacity(mut self, cache_capacity: u64) -> Self {
        self.cache_capacity = cache_capacity;
        self.rebuild_cache();
        self
    }

    /// Sets the timeout for fetching robots.txt files.
    pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
        self.request_timeout = request_timeout;
        self
    }

    /// Rebuilds the cache with the current settings, dropping all cached entries.
    fn rebuild_cache(&mut self) {
        self.cache = Cache::builder()
            .time_to_live(self.cache_ttl)
            .max_capacity(self.cache_capacity)
            .build();
    }

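    /// Fetches and returns the robots.txt body for `origin`.
    ///
    /// Any failure (network error, non-2xx status, or a body that is not valid
    /// UTF-8) yields an empty body, so the site is treated permissively
    /// (fail-open) rather than blocking the crawl.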
    async fn fetch_robots_content<C: SimpleHttpClient>(
        &self,
        client: &C,
        origin: &str,
    ) -> Arc<String> {
        let robots_url = format!("{}/robots.txt", origin);
        debug!("Fetching robots.txt from: {}", robots_url);

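        // Empty body = no rules, so the matcher allows every URL (permissive fallback).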
        let permissive = || Arc::new(String::new());

        match client.get_text(&robots_url, self.request_timeout).await {
            Ok((status, body)) if status.is_success() => match String::from_utf8(body.into()) {
                Ok(text) => Arc::new(text),
                Err(e) => {
                    warn!("robots.txt {} is not valid UTF-8: {}", robots_url, e);
                    permissive()
                }
            },
            Ok((status, _)) => {
                debug!(
                    "robots.txt {} returned {} — allowing all",
                    robots_url, status
                );
                permissive()
            }
            Err(e) => {
                warn!("Failed to fetch robots.txt {}: {}", robots_url, e);
                permissive()
            }
        }
    }
}

#[async_trait]
impl<C: SimpleHttpClient> Middleware<C> for RobotsTxtMiddleware {
    fn name(&self) -> &str {
        "RobotsTxtMiddleware"
    }

    async fn process_request(
        &mut self,
        client: &C,
        request: Request,
    ) -> Result<MiddlewareAction<Request>, SpiderError> {
        let url = request.url.clone();
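        // Opaque origins (e.g. `data:` URLs) serialize to "null"; such requests skip the robots check.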
        let origin = match url.origin().unicode_serialization() {
            s if s == "null" => return Ok(MiddlewareAction::Continue(request)),
            s => s,
        };

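        // One robots.txt body is cached per origin; on a miss, fetch it and populate the cache.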
        let robots_body = match self.cache.get(&origin).await {
            Some(body) => body,
            None => {
                let body = self.fetch_robots_content(client, &origin).await;
                self.cache.insert(origin.clone(), body.clone()).await;
                body
            }
        };

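        // Only a request that carries a User-Agent header can be matched against the
        // rules; requests without one (or with a disallowed agent) fall through and
        // are rejected below.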
        if let Some(user_agent) = request.headers.get(USER_AGENT) {
            let ua = user_agent
                .to_str()
                .map_err(|e| SpiderError::HeaderValueError(e.to_string()))?;

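            // DefaultMatcher applies Google-style robots.txt matching semantics
            // (the robotstxt crate is a port of Google's parser).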
            let mut matcher = DefaultMatcher::default();
            if matcher.one_agent_allowed_by_robots(robots_body.as_str(), ua, url.as_str()) {
                return Ok(MiddlewareAction::Continue(request));
            }
        }

        debug!("Blocked by robots.txt: {}", url);
        Err(SpiderError::BlockedByRobotsTxt)
    }
}