spider_lib/builder.rs

use crate::ConsoleWriterPipeline;
use crate::checkpoint::Checkpoint;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::downloader::Downloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
use std::fs;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, warn};

use super::Crawler;

/// Configuration for the crawler's concurrency settings.
pub struct CrawlerConfig {
    /// The maximum number of concurrent downloads.
    pub max_concurrent_downloads: usize,
    /// The number of workers dedicated to parsing responses.
    pub parser_workers: usize,
    /// The maximum number of concurrent item processing pipelines.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

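/// Builder for assembling a [`Crawler`]: configures concurrency limits, the
/// [`Downloader`], middlewares, item pipelines, and optional checkpointing
/// for a given [`Spider`].
///
/// # Example
///
/// A minimal sketch of typical usage; `MySpider` stands in for any type
/// implementing [`Spider`]:
///
/// ```ignore
/// let crawler = CrawlerBuilder::new(MySpider::new())
///     .max_concurrent_downloads(8)
///     .add_pipeline(ConsoleWriterPipeline::new())
///     .build()
///     .await?;
/// ```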
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    checkpoint_path: Option<PathBuf>,
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            checkpoint_path: None,
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Creates a new `CrawlerBuilder` for a given spider.
    pub fn new(spider: S) -> Self
    where
        D: Default,
    {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }

    /// Sets the maximum number of concurrent downloads.
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the maximum number of concurrent parser workers.
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrent pipelines.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

    /// Sets a custom downloader for the crawler.
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Adds a middleware to the crawler.
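    ///
    /// A minimal sketch, reusing the [`UserAgentMiddleware`] that this module
    /// already wires up by default; the user-agent string and the `spider`
    /// value are illustrative:
    ///
    /// ```ignore
    /// let ua = UserAgentMiddleware::builder()
    ///     .source(UserAgentSource::List(vec!["my-crawler/0.1".to_string()]))
    ///     .fallback_user_agent("my-crawler/0.1".to_string())
    ///     .build()?;
    /// let builder = CrawlerBuilder::new(spider).add_middleware(ua);
    /// ```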
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        D: Downloader,
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Adds an item pipeline to the crawler.
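    ///
    /// If no pipeline is added, [`build`](Self::build) installs a
    /// [`ConsoleWriterPipeline`] by default.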
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

    /// Enables checkpointing and sets the path for the checkpoint file.
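    ///
    /// If a file already exists at `path`, [`build`](Self::build) attempts to
    /// restore the scheduler and pipeline state from it. A minimal sketch
    /// (the path and interval values are illustrative):
    ///
    /// ```ignore
    /// let builder = CrawlerBuilder::new(spider)
    ///     .with_checkpoint_path("crawl.checkpoint")
    ///     .with_checkpoint_interval(Duration::from_secs(60));
    /// ```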
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the interval for periodic checkpointing.
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    /// Builds the `Crawler` instance.
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync,
    {
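        // Fall back to the console-writing pipeline when no item pipeline was configured.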
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

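        // If checkpointing is configured, try to restore scheduler and pipeline
        // state from a previous run; failures are only logged and the crawl starts fresh.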
        let mut initial_scheduler_state = None;
        let mut loaded_pipelines_state = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state.clone());

        // Restore pipeline states now that pipelines are built
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

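        // If no UserAgentMiddleware was configured, install a default one built
        // from the crate name and version, placed first in the middleware list.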
        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);

        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            self.checkpoint_path.take(),
            self.checkpoint_interval,
        );

        Ok(crawler)
    }
}